diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go index 440f3eb..c888433 100644 --- a/internal/jet/column_assigment.go +++ b/internal/jet/column_assigment.go @@ -11,6 +11,13 @@ type columnAssigmentImpl struct { expression Expression } +func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment { + return &columnAssigmentImpl{ + column: serializer, + expression: expression, + } +} + func (a columnAssigmentImpl) isColumnAssigment() {} func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 97f90d6..cfedde4 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -9,6 +9,7 @@ import ( "github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/qrm" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,8 +18,6 @@ import ( "runtime" "testing" "time" - - "github.com/google/go-cmp/cmp" ) // UnixTimeComparer will compare time equality while ignoring time zone diff --git a/postgres/columns.go b/postgres/columns.go index 819da38..a70c234 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -109,13 +109,20 @@ type ColumnInterval interface { jet.Column From(subQuery SelectTable) ColumnInterval + SET(intervalExp IntervalExpression) ColumnAssigment } +//------------------------------------------------------// + type intervalColumnImpl struct { jet.ColumnExpressionImpl intervalInterfaceImpl } +func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment { + return jet.NewColumnAssignment(i, intervalExp) +} + func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { newIntervalColumn := IntervalColumn(i.Name()) jet.SetTableName(newIntervalColumn, i.TableName()) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 2f1be14..70c4332 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -3,6 +3,8 @@ package postgres import ( "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/stretchr/testify/assert" + "log/slog" "testing" "time" @@ -931,6 +933,68 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } +func TestIntervalSetFunctionality(t *testing.T) { + + t.Run("updateQueryIntervalTest", func(t *testing.T) { + slog.Info("Running test", slog.Any("test", t.Name())) + expectedQuery := ` +UPDATE test_sample.employee +SET pto_accrual = INTERVAL '3 HOUR' +WHERE employee.employee_id = $1 +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var windy model.Employee + windy.PtoAccrual = ptr.Of("3h") + stmt := Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + + testutils.AssertStatementSql(t, stmt, expectedQuery) + err := stmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "03:00:00") + + }) + }) + + t.Run("upsertQueryIntervalTest", func(t *testing.T) { + expectedQuery := ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (employee_id) DO UPDATE + SET pto_accrual = excluded.pto_accrual +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var employee model.Employee + employee.PtoAccrual = ptr.Of("5h") + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) + + testutils.AssertStatementSql(t, stmt, expectedQuery) + err := stmt.Query(tx, &employee) + assert.Nil(t, err) + assert.Equal(t, *employee.PtoAccrual, "05:00:00") + + }) + }) +} + func TestInterval(t *testing.T) { skipForCockroachDB(t) diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index a079091..7edefd4 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -93,9 +93,9 @@ func TestInsertOnConflict(t *testing.T) { ON_CONFLICT().DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -111,9 +111,9 @@ ON CONFLICT DO NOTHING; ON_CONFLICT(Employee.EmployeeID).DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT (employee_id) DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -130,9 +130,9 @@ ON CONFLICT (employee_id) DO NOTHING; ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -234,8 +234,8 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE ON_CONFLICT().DO_UPDATE(nil) testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5); +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6); `) testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index a1d4c2d..f252631 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -331,11 +331,13 @@ SELECT employee.employee_id AS "employee.employee_id", employee.last_name AS "employee.last_name", employee.employment_date AS "employee.employment_date", employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual", manager.employee_id AS "manager.employee_id", manager.first_name AS "manager.first_name", manager.last_name AS "manager.last_name", manager.employment_date AS "manager.employment_date", - manager.manager_id AS "manager.manager_id" + manager.manager_id AS "manager.manager_id", + manager.pto_accrual AS "manager.pto_accrual" FROM test_sample.employee LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) ORDER BY employee.employee_id; @@ -370,6 +372,7 @@ ORDER BY employee.employee_id; LastName: "Hays", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06.1 +0100 CET", 1), ManagerID: nil, + PtoAccrual: ptr.Of("22:00:00"), }) require.True(t, dest[0].Manager == nil) diff --git a/tests/testdata b/tests/testdata index 1e9247e..6a39774 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1e9247e333babd5172cf162e38518d993f5f3df4 +Subproject commit 6a397747d310938b41d3950d68009578180d3dd5