diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 2360a16..637ed05 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -469,3 +469,30 @@ FOR` assert.NilError(t, err) } } + +func TestExpressionWrappers(t *testing.T) { + query := SELECT( + BoolExp(Raw("true")), + IntExp(Raw("11")), + FloatExp(Raw("11.22")), + StringExp(Raw("'stringer'")), + TimeExp(Raw("'raw'")), + TimestampExp(Raw("'raw'")), + DateTimeExp(Raw("'raw'")), + DateExp(Raw("'date'")), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT true, + 11, + 11.22, + 'stringer', + 'raw', + 'raw', + 'raw', + 'date'; +`) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 066f64c..8d63d8c 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -2,7 +2,6 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum" @@ -1273,36 +1272,68 @@ OFFSET 20; } func TestAllSetOperators(t *testing.T) { + var select1 = Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610)))) + var select2 = Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630)))) - select1 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610)))) - select2 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630)))) - - type setOperator func(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement - operators := []setOperator{ - UNION, - UNION_ALL, - INTERSECT, - INTERSECT_ALL, - } - - expectedDestLen := []int{ - 20, - 20, - 0, - 0, - 10, - 10, - } - - for i, operator := range operators { - query := operator(select1, select2) + t.Run("UNION", func(t *testing.T) { + query := select1.UNION(select2) dest := []model.Payment{} err := query.Query(db, &dest) assert.NilError(t, err) - assert.Equal(t, len(dest), expectedDestLen[i]) - } + assert.Equal(t, len(dest), 20) + }) + + t.Run("UNION_ALL", func(t *testing.T) { + query := select1.UNION_ALL(select2) + + dest := []model.Payment{} + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 20) + }) + + t.Run("INTERSECT", func(t *testing.T) { + query := select1.INTERSECT(select2) + + dest := []model.Payment{} + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 0) + }) + + t.Run("INTERSECT_ALL", func(t *testing.T) { + query := select1.INTERSECT_ALL(select2) + + dest := []model.Payment{} + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 0) + }) + + t.Run("EXCEPT", func(t *testing.T) { + query := select1.EXCEPT(select2) + + dest := []model.Payment{} + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 10) + }) + + t.Run("EXCEPT_ALL", func(t *testing.T) { + query := select1.EXCEPT_ALL(select2) + + dest := []model.Payment{} + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 10) + }) } func TestSelectWithCase(t *testing.T) { @@ -1559,3 +1590,32 @@ func TestQuickStartWithSubQueries(t *testing.T) { //jsonSave("./testdata/quick-start-dest2.json", dest2) testutils.AssertJSONFile(t, dest2, "./postgres/testdata/quick-start-dest2.json") } + +func TestExpressionWrappers(t *testing.T) { + query := SELECT( + BoolExp(Raw("true")), + IntExp(Raw("11")), + FloatExp(Raw("11.22")), + StringExp(Raw("'stringer'")), + TimeExp(Raw("'raw'")), + TimezExp(Raw("'raw'")), + TimestampExp(Raw("'raw'")), + TimestampzExp(Raw("'raw'")), + DateExp(Raw("'date'")), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT true, + 11, + 11.22, + 'stringer', + 'raw', + 'raw', + 'raw', + 'raw', + 'date'; +`) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index 88359e0..3f00e7b 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -23,6 +23,7 @@ func assertExecErr(t *testing.T, stmt jet.Statement, errorStr string) { assert.Error(t, err, errorStr) } + func BoolPtr(b bool) *bool { return &b }