Statements execution with context.
This commit is contained in:
parent
cdfd8f1dff
commit
1ac324e198
11 changed files with 146 additions and 28 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
|
|
@ -75,6 +76,14 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er
|
||||||
return Query(d, db, destination)
|
return Query(d, db, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *deleteStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
return QueryContext(d, db, context, destination)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||||
return Exec(d, db)
|
return Exec(d, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *deleteStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
return ExecContext(d, db, context)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,13 @@
|
||||||
package execution
|
package execution
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
type Db interface {
|
type Db interface {
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package execution
|
package execution
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
@ -12,7 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(db Db, query string, args []interface{}, destinationPtr interface{}) error {
|
func Query(db Db, context context.Context, query string, args []interface{}, destinationPtr interface{}) error {
|
||||||
|
|
||||||
if destinationPtr == nil {
|
if destinationPtr == nil {
|
||||||
return errors.New("Destination is nil. ")
|
return errors.New("Destination is nil. ")
|
||||||
|
|
@ -24,12 +25,12 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
if destinationPtrType.Elem().Kind() == reflect.Slice {
|
if destinationPtrType.Elem().Kind() == reflect.Slice {
|
||||||
return queryToSlice(db, query, args, destinationPtr)
|
return queryToSlice(db, context, query, args, destinationPtr)
|
||||||
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
|
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
|
||||||
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
|
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
|
||||||
tempSliceValue := tempSlicePtrValue.Elem()
|
tempSliceValue := tempSlicePtrValue.Elem()
|
||||||
|
|
||||||
err := queryToSlice(db, query, args, tempSlicePtrValue.Interface())
|
err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -53,7 +54,7 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) error {
|
func queryToSlice(db Db, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
return errors.New("db is nil")
|
return errors.New("db is nil")
|
||||||
}
|
}
|
||||||
|
|
@ -67,7 +68,11 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{})
|
||||||
return errors.New("Destination has to be a pointer to slice. ")
|
return errors.New("Destination has to be a pointer to slice. ")
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Query(query, args...)
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.QueryContext(ctx, query, args...)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
|
|
@ -149,6 +150,14 @@ func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) er
|
||||||
return Query(i, db, destination)
|
return Query(i, db, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *insertStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
return QueryContext(i, db, context, destination)
|
||||||
|
}
|
||||||
|
|
||||||
func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||||
return Exec(i, db)
|
return Exec(i, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *insertStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
return ExecContext(i, db, context)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
@ -96,6 +97,14 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro
|
||||||
return Query(l, db, destination)
|
return Query(l, db, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *lockStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
return QueryContext(l, db, context, destination)
|
||||||
|
}
|
||||||
|
|
||||||
func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) {
|
func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) {
|
||||||
return Exec(l, db)
|
return Exec(l, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *lockStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
return ExecContext(l, db, context)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
|
|
@ -294,6 +295,14 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er
|
||||||
return Query(s, db, destination)
|
return Query(s, db, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *selectStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
return QueryContext(s, db, context, destination)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||||
return Exec(s, db)
|
return Exec(s, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *selectStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
return ExecContext(s, db, context)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
|
|
@ -206,6 +207,14 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error
|
||||||
return Query(s, db, destination)
|
return Query(s, db, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *setStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
return QueryContext(s, db, context, destination)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||||
return Exec(s, db)
|
return Exec(s, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *setStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
return ExecContext(s, db, context)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
@ -14,7 +15,10 @@ type Statement interface {
|
||||||
DebugSql() (query string, err error)
|
DebugSql() (query string, err error)
|
||||||
|
|
||||||
Query(db execution.Db, destination interface{}) error
|
Query(db execution.Db, destination interface{}) error
|
||||||
|
QueryContext(db execution.Db, context context.Context, destination interface{}) error
|
||||||
|
|
||||||
Exec(db execution.Db) (sql.Result, error)
|
Exec(db execution.Db) (sql.Result, error)
|
||||||
|
ExecContext(db execution.Db, context context.Context) (sql.Result, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DebugSql(statement Statement) (string, error) {
|
func DebugSql(statement Statement) (string, error) {
|
||||||
|
|
@ -33,3 +37,43 @@ func DebugSql(statement Statement) (string, error) {
|
||||||
|
|
||||||
return debugSql, nil
|
return debugSql, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Query(statement Statement, db execution.Db, destination interface{}) error {
|
||||||
|
query, args, err := statement.Sql()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return execution.Query(db, context.Background(), query, args, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryContext(statement Statement, db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
query, args, err := statement.Sql()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return execution.Query(db, context, query, args, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
|
||||||
|
query, args, err := statement.Sql()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Exec(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecContext(statement Statement, db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
query, args, err := statement.Sql()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.ExecContext(context, query, args...)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||||
|
|
@ -142,6 +143,14 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er
|
||||||
return Query(u, db, destination)
|
return Query(u, db, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||||
|
return QueryContext(u, db, context, destination)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||||
return Exec(u, db)
|
return Exec(u, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||||
|
return ExecContext(u, db, context)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
|
||||||
"github.com/serenize/snaker"
|
"github.com/serenize/snaker"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
@ -185,23 +183,3 @@ func mustBe(v reflect.Value, expected reflect.Kind) {
|
||||||
panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String())
|
panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Query(statement Statement, db execution.Db, destination interface{}) error {
|
|
||||||
query, args, err := statement.Sql()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return execution.Query(db, query, args, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
|
|
||||||
query, args, err := statement.Sql()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.Exec(query, args...)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package tests
|
package tests
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
|
@ -10,6 +11,7 @@ import (
|
||||||
"gotest.tools/assert"
|
"gotest.tools/assert"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSelect(t *testing.T) {
|
func TestSelect(t *testing.T) {
|
||||||
|
|
@ -152,6 +154,36 @@ ORDER BY "Album.AlbumId";
|
||||||
assert.DeepEqual(t, dest[1], album2)
|
assert.DeepEqual(t, dest[1], album2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryWithContext(t *testing.T) {
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dest := []model.Album{}
|
||||||
|
|
||||||
|
err := Album.
|
||||||
|
CROSS_JOIN(Track).
|
||||||
|
CROSS_JOIN(InvoiceLine).
|
||||||
|
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
|
||||||
|
QueryContext(db, ctx, &dest)
|
||||||
|
|
||||||
|
assert.Error(t, err, "context deadline exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecWithContext(t *testing.T) {
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := Album.
|
||||||
|
CROSS_JOIN(Track).
|
||||||
|
CROSS_JOIN(InvoiceLine).
|
||||||
|
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
|
||||||
|
ExecContext(db, ctx)
|
||||||
|
|
||||||
|
assert.Error(t, err, "pq: canceling statement due to user request")
|
||||||
|
}
|
||||||
|
|
||||||
func TestSubQueriesForQuotedNames(t *testing.T) {
|
func TestSubQueriesForQuotedNames(t *testing.T) {
|
||||||
first10Artist := Artist.
|
first10Artist := Artist.
|
||||||
SELECT(Artist.AllColumns).
|
SELECT(Artist.AllColumns).
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue