Statements execution with context.

This commit is contained in:
go-jet 2019-06-20 12:22:19 +02:00
parent cdfd8f1dff
commit 1ac324e198
11 changed files with 146 additions and 28 deletions

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
@ -75,6 +76,14 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(d, db, context, destination)
}
func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(d, db, context)
}

View file

@ -1,8 +1,13 @@
package execution
import "database/sql"
import (
"context"
"database/sql"
)
type Db interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}

View file

@ -1,6 +1,7 @@
package execution
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@ -12,7 +13,7 @@ import (
"time"
)
func Query(db Db, query string, args []interface{}, destinationPtr interface{}) error {
func Query(db Db, context context.Context, query string, args []interface{}, destinationPtr interface{}) error {
if destinationPtr == nil {
return errors.New("Destination is nil. ")
@ -24,12 +25,12 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
}
if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(db, query, args, destinationPtr)
return queryToSlice(db, context, query, args, destinationPtr)
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
err := queryToSlice(db, query, args, tempSlicePtrValue.Interface())
err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface())
if err != nil {
return err
@ -53,7 +54,7 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
}
}
func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) error {
func queryToSlice(db Db, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error {
if db == nil {
return errors.New("db is nil")
}
@ -67,7 +68,11 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{})
return errors.New("Destination has to be a pointer to slice. ")
}
rows, err := db.Query(query, args...)
if ctx == nil {
ctx = context.Background()
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return err

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
@ -149,6 +150,14 @@ func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(i, db, context, destination)
}
func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(i, db)
}
func (i *insertStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(i, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"github.com/go-jet/jet/sqlbuilder/execution"
"github.com/pkg/errors"
@ -96,6 +97,14 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro
return Query(l, db, destination)
}
func (l *lockStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(l, db, context, destination)
}
func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) {
return Exec(l, db)
}
func (l *lockStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(l, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
@ -294,6 +295,14 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(s, db, destination)
}
func (s *selectStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(s, db, context, destination)
}
func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(s, db)
}
func (s *selectStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(s, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
@ -206,6 +207,14 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error
return Query(s, db, destination)
}
func (s *setStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(s, db, context, destination)
}
func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(s, db)
}
func (s *setStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(s, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"github.com/go-jet/jet/sqlbuilder/execution"
"strconv"
@ -14,7 +15,10 @@ type Statement interface {
DebugSql() (query string, err error)
Query(db execution.Db, destination interface{}) error
QueryContext(db execution.Db, context context.Context, destination interface{}) error
Exec(db execution.Db) (sql.Result, error)
ExecContext(db execution.Db, context context.Context) (sql.Result, error)
}
func DebugSql(statement Statement) (string, error) {
@ -33,3 +37,43 @@ func DebugSql(statement Statement) (string, error) {
return debugSql, nil
}
func Query(statement Statement, db execution.Db, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, context.Background(), query, args, destination)
}
func QueryContext(statement Statement, db execution.Db, context context.Context, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, context, query, args, destination)
}
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {
return
}
return db.Exec(query, args...)
}
func ExecContext(statement Statement, db execution.Db, context context.Context) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {
return
}
return db.ExecContext(context, query, args...)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
@ -142,6 +143,14 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(u, db, destination)
}
func (u *updateStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(u, db, context, destination)
}
func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(u, db)
}
func (u *updateStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(u, db, context)
}

View file

@ -1,9 +1,7 @@
package sqlbuilder
import (
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
"github.com/serenize/snaker"
"reflect"
)
@ -185,23 +183,3 @@ func mustBe(v reflect.Value, expected reflect.Kind) {
panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String())
}
}
func Query(statement Statement, db execution.Db, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, query, args, destination)
}
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {
return
}
return db.Exec(query, args...)
}

View file

@ -1,6 +1,7 @@
package tests
import (
"context"
"encoding/json"
"fmt"
"github.com/davecgh/go-spew/spew"
@ -10,6 +11,7 @@ import (
"gotest.tools/assert"
"io/ioutil"
"testing"
"time"
)
func TestSelect(t *testing.T) {
@ -152,6 +154,36 @@ ORDER BY "Album.AlbumId";
assert.DeepEqual(t, dest[1], album2)
}
func TestQueryWithContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
dest := []model.Album{}
err := Album.
CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
QueryContext(db, ctx, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestExecWithContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err := Album.
CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
ExecContext(db, ctx)
assert.Error(t, err, "pq: canceling statement due to user request")
}
func TestSubQueriesForQuotedNames(t *testing.T) {
first10Artist := Artist.
SELECT(Artist.AllColumns).