diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index 1c38783..868a579 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" @@ -75,6 +76,14 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er return Query(d, db, destination) } +func (d *deleteStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { + return QueryContext(d, db, context, destination) +} + func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { return Exec(d, db) } + +func (d *deleteStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { + return ExecContext(d, db, context) +} diff --git a/sqlbuilder/execution/db.go b/sqlbuilder/execution/db.go index 61747f1..ab55d2d 100644 --- a/sqlbuilder/execution/db.go +++ b/sqlbuilder/execution/db.go @@ -1,8 +1,13 @@ package execution -import "database/sql" +import ( + "context" + "database/sql" +) type Db interface { Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 6609902..843d6a8 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -1,6 +1,7 @@ package execution import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -12,7 +13,7 @@ import ( "time" ) -func Query(db Db, query string, args []interface{}, destinationPtr interface{}) error { +func Query(db Db, context context.Context, query string, args []interface{}, destinationPtr interface{}) error { if destinationPtr == nil { return errors.New("Destination is nil. ") @@ -24,12 +25,12 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{}) } if destinationPtrType.Elem().Kind() == reflect.Slice { - return queryToSlice(db, query, args, destinationPtr) + return queryToSlice(db, context, query, args, destinationPtr) } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() - err := queryToSlice(db, query, args, tempSlicePtrValue.Interface()) + err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface()) if err != nil { return err @@ -53,7 +54,7 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{}) } } -func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) error { +func queryToSlice(db Db, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error { if db == nil { return errors.New("db is nil") } @@ -67,7 +68,11 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) return errors.New("Destination has to be a pointer to slice. ") } - rows, err := db.Query(query, args...) + if ctx == nil { + ctx = context.Background() + } + + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return err diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index b97289e..2991083 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" @@ -149,6 +150,14 @@ func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) er return Query(i, db, destination) } +func (i *insertStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { + return QueryContext(i, db, context, destination) +} + func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { return Exec(i, db) } + +func (i *insertStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { + return ExecContext(i, db, context) +} diff --git a/sqlbuilder/lock_statement.go b/sqlbuilder/lock_statement.go index 4d960ad..cd6894d 100644 --- a/sqlbuilder/lock_statement.go +++ b/sqlbuilder/lock_statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "github.com/go-jet/jet/sqlbuilder/execution" "github.com/pkg/errors" @@ -96,6 +97,14 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro return Query(l, db, destination) } +func (l *lockStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { + return QueryContext(l, db, context, destination) +} + func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) { return Exec(l, db) } + +func (l *lockStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { + return ExecContext(l, db, context) +} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 4a0bfec..b273d44 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" @@ -294,6 +295,14 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er return Query(s, db, destination) } +func (s *selectStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { + return QueryContext(s, db, context, destination) +} + func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { return Exec(s, db) } + +func (s *selectStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { + return ExecContext(s, db, context) +} diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index 4f84be8..44cd24b 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" @@ -206,6 +207,14 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error return Query(s, db, destination) } +func (s *setStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { + return QueryContext(s, db, context, destination) +} + func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { return Exec(s, db) } + +func (s *setStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { + return ExecContext(s, db, context) +} diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 6e26919..9e806ac 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "github.com/go-jet/jet/sqlbuilder/execution" "strconv" @@ -14,7 +15,10 @@ type Statement interface { DebugSql() (query string, err error) Query(db execution.Db, destination interface{}) error + QueryContext(db execution.Db, context context.Context, destination interface{}) error + Exec(db execution.Db) (sql.Result, error) + ExecContext(db execution.Db, context context.Context) (sql.Result, error) } func DebugSql(statement Statement) (string, error) { @@ -33,3 +37,43 @@ func DebugSql(statement Statement) (string, error) { return debugSql, nil } + +func Query(statement Statement, db execution.Db, destination interface{}) error { + query, args, err := statement.Sql() + + if err != nil { + return err + } + + return execution.Query(db, context.Background(), query, args, destination) +} + +func QueryContext(statement Statement, db execution.Db, context context.Context, destination interface{}) error { + query, args, err := statement.Sql() + + if err != nil { + return err + } + + return execution.Query(db, context, query, args, destination) +} + +func Exec(statement Statement, db execution.Db) (res sql.Result, err error) { + query, args, err := statement.Sql() + + if err != nil { + return + } + + return db.Exec(query, args...) +} + +func ExecContext(statement Statement, db execution.Db, context context.Context) (res sql.Result, err error) { + query, args, err := statement.Sql() + + if err != nil { + return + } + + return db.ExecContext(context, query, args...) +} diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index 03b64aa..c491e79 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" @@ -142,6 +143,14 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er return Query(u, db, destination) } +func (u *updateStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { + return QueryContext(u, db, context, destination) +} + func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { return Exec(u, db) } + +func (u *updateStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { + return ExecContext(u, db, context) +} diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 25c4929..8b6bec3 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -1,9 +1,7 @@ package sqlbuilder import ( - "database/sql" "errors" - "github.com/go-jet/jet/sqlbuilder/execution" "github.com/serenize/snaker" "reflect" ) @@ -185,23 +183,3 @@ func mustBe(v reflect.Value, expected reflect.Kind) { panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String()) } } - -func Query(statement Statement, db execution.Db, destination interface{}) error { - query, args, err := statement.Sql() - - if err != nil { - return err - } - - return execution.Query(db, query, args, destination) -} - -func Exec(statement Statement, db execution.Db) (res sql.Result, err error) { - query, args, err := statement.Sql() - - if err != nil { - return - } - - return db.Exec(query, args...) -} diff --git a/tests/chinook_db_test.go b/tests/chinook_db_test.go index 30fea09..1c46b39 100644 --- a/tests/chinook_db_test.go +++ b/tests/chinook_db_test.go @@ -1,6 +1,7 @@ package tests import ( + "context" "encoding/json" "fmt" "github.com/davecgh/go-spew/spew" @@ -10,6 +11,7 @@ import ( "gotest.tools/assert" "io/ioutil" "testing" + "time" ) func TestSelect(t *testing.T) { @@ -152,6 +154,36 @@ ORDER BY "Album.AlbumId"; assert.DeepEqual(t, dest[1], album2) } +func TestQueryWithContext(t *testing.T) { + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + dest := []model.Album{} + + err := Album. + CROSS_JOIN(Track). + CROSS_JOIN(InvoiceLine). + SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns). + QueryContext(db, ctx, &dest) + + assert.Error(t, err, "context deadline exceeded") +} + +func TestExecWithContext(t *testing.T) { + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := Album. + CROSS_JOIN(Track). + CROSS_JOIN(InvoiceLine). + SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns). + ExecContext(db, ctx) + + assert.Error(t, err, "pq: canceling statement due to user request") +} + func TestSubQueriesForQuotedNames(t *testing.T) { first10Artist := Artist. SELECT(Artist.AllColumns).