diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index fdb6ac9..62e0bc6 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -2,8 +2,10 @@ package postgres import ( "database/sql" + "github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" + "log/slog" "testing" "time" @@ -932,41 +934,79 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } -func TestIntervalUpsert(t *testing.T) { - testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { - stmt := SELECT(Employee.AllColumns).FROM(Employee). - WHERE(Employee.EmployeeID.EQ(Int(1))) +func TestIntervalSetFunctionality(t *testing.T) { + updateQuery := ` +UPDATE test_sample.employee +SET pto_accrual = INTERVAL '3 HOUR' +WHERE employee.employee_id = $1 +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` + insertQuery := ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (employee_id) DO UPDATE + SET pto_accrual = excluded.pto_accrual +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` - //Validate initial dataset - var windy model.Employee - err := stmt.Query(tx, &windy) - assert.Equal(t, windy.EmployeeID, int32(1)) - assert.Equal(t, windy.FirstName, "Windy") - assert.Equal(t, windy.LastName, "Hays") - assert.Equal(t, *windy.PtoAccrual, "22:00:00") - assert.Nil(t, err) - windy.PtoAccrual = ptr.Of("3h") - //Update data - updateStmt := Employee.UPDATE(Employee.PtoAccrual).SET( - Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), - ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + testCases := []struct { + expectedQuery string + name string + duration string + expectedInterval string + statement func(employee *model.Employee) jet.Statement + }{ + { + name: "updateQuery", + expectedQuery: updateQuery, + duration: "3h", + expectedInterval: "03:00:00", + statement: func(employee *model.Employee) jet.Statement { + return Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + }, + }, + { + expectedQuery: insertQuery, + name: "insertQuery", + duration: "5h", + expectedInterval: "05:00:00", + statement: func(employee *model.Employee) jet.Statement { + return Employee.INSERT(Employee.AllColumns). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) + }, + }, + } - err = updateStmt.Query(tx, &windy) - err = stmt.Query(tx, &windy) - assert.Nil(t, err) - assert.Equal(t, *windy.PtoAccrual, "03:00:00") - //Upsert dataset with a different value - windy.PtoAccrual = ptr.Of("5h") - insertStmt := Employee.INSERT(Employee.AllColumns). - MODEL(&windy). - ON_CONFLICT(Employee.EmployeeID). - DO_UPDATE(SET( - Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), - )).RETURNING(Employee.AllColumns) - err = insertStmt.Query(tx, &windy) - assert.Nil(t, err) - assert.Equal(t, *windy.PtoAccrual, "05:00:00") - }) + for _, tc := range testCases { + slog.Info("Running test", slog.Any("test", tc.name)) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var windy model.Employee + windy.PtoAccrual = ptr.Of(tc.duration) + stmt := tc.statement(&windy) + + testutils.AssertStatementSql(t, stmt, tc.expectedQuery) + err := stmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, tc.expectedInterval) + + }) + } } func TestInterval(t *testing.T) {