diff --git a/postgres/clause_test.go b/postgres/clause_test.go index 5602505..28be315 100644 --- a/postgres/clause_test.go +++ b/postgres/clause_test.go @@ -29,7 +29,7 @@ ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) ) assertClauseSerialize(t, onConflict, ` ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE - SET col_bool = $1, + SET col_bool = $1::boolean, col_int = $2 WHERE table2.col_float > $3`) } diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index d98d8f3..45ed739 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -33,7 +33,7 @@ func TestExists(t *testing.T) { ).EQ(Bool(true)), `((EXISTS ( SELECT $1 -)) = $2)`, int64(1), true) +)) = $2::boolean)`, int64(1), true) assertProjectionSerialize(t, EXISTS( SELECT(Int(1)), diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index ad687b5..3ec333e 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -165,7 +165,7 @@ VALUES ('one', 'two'), ('1', '2'), ('theta', 'beta') ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE - SET col_bool = TRUE, + SET col_bool = TRUE::boolean, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 @@ -191,7 +191,7 @@ INSERT INTO db.table1 (col1, col_bool) VALUES ('one', 'two'), ('1', '2') ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE - SET col_bool = FALSE, + SET col_bool = FALSE::boolean, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 diff --git a/postgres/interval_expression_test.go b/postgres/interval_expression_test.go index f2fa9fe..8d6e647 100644 --- a/postgres/interval_expression_test.go +++ b/postgres/interval_expression_test.go @@ -67,7 +67,7 @@ func TestIntervalExpressionMethods(t *testing.T) { assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)), - "((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false) + "((table1.col_interval = INTERVAL '11 MINUTE') = $1::boolean)", false) assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)") assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)") assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)") diff --git a/postgres/literal.go b/postgres/literal.go index 8ee3235..524b251 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -6,35 +6,52 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -// Bool creates new bool literal expression -var Bool = jet.Bool +func Bool(value bool) BoolExpression { + return CAST(jet.Bool(value)).AS_BOOL() +} // Int is constructor for 64 bit signed integer expressions literals. var Int = jet.Int // Int8 is constructor for 8 bit signed integer expressions literals. -var Int8 = jet.Int8 +func Int8(value int8) IntegerExpression { + return CAST(jet.Int8(value)).AS_SMALLINT() +} // Int16 is constructor for 16 bit signed integer expressions literals. -var Int16 = jet.Int16 +func Int16(value int16) IntegerExpression { + return CAST(jet.Int16(value)).AS_SMALLINT() +} // Int32 is constructor for 32 bit signed integer expressions literals. -var Int32 = jet.Int32 +func Int32(value int32) IntegerExpression { + return CAST(jet.Int32(value)).AS_INTEGER() +} // Int64 is constructor for 64 bit signed integer expressions literals. -var Int64 = jet.Int +func Int64(value int64) IntegerExpression { + return CAST(jet.Int(value)).AS_BIGINT() +} // Uint8 is constructor for 8 bit unsigned integer expressions literals. -var Uint8 = jet.Uint8 +func Uint8(value uint8) IntegerExpression { + return CAST(jet.Uint8(value)).AS_SMALLINT() +} // Uint16 is constructor for 16 bit unsigned integer expressions literals. -var Uint16 = jet.Uint16 +func Uint16(value uint16) IntegerExpression { + return CAST(jet.Uint16(value)).AS_INTEGER() +} // Uint32 is constructor for 32 bit unsigned integer expressions literals. -var Uint32 = jet.Uint32 +func Uint32(value uint32) IntegerExpression { + return CAST(jet.Uint32(value)).AS_BIGINT() +} // Uint64 is constructor for 64 bit unsigned integer expressions literals. -var Uint64 = jet.Uint64 +func Uint64(value uint64) IntegerExpression { + return CAST(jet.Uint64(value)).AS_BIGINT() +} // Float creates new float literal expression var Float = jet.Float diff --git a/postgres/literal_test.go b/postgres/literal_test.go index 52a15a0..f95e486 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -7,7 +7,7 @@ import ( ) func TestBool(t *testing.T) { - assertSerialize(t, Bool(false), `$1`, false) + assertSerialize(t, Bool(false), `$1::boolean`, false) } func TestInt(t *testing.T) { @@ -16,42 +16,42 @@ func TestInt(t *testing.T) { func TestInt8(t *testing.T) { val := int8(math.MinInt8) - assertSerialize(t, Int8(val), `$1`, val) + assertSerialize(t, Int8(val), `$1::smallint`, val) } func TestInt16(t *testing.T) { val := int16(math.MinInt16) - assertSerialize(t, Int16(val), `$1`, val) + assertSerialize(t, Int16(val), `$1::smallint`, val) } func TestInt32(t *testing.T) { val := int32(math.MinInt32) - assertSerialize(t, Int32(val), `$1`, val) + assertSerialize(t, Int32(val), `$1::integer`, val) } func TestInt64(t *testing.T) { val := int64(math.MinInt64) - assertSerialize(t, Int64(val), `$1`, val) + assertSerialize(t, Int64(val), `$1::bigint`, val) } func TestUint8(t *testing.T) { val := uint8(math.MaxUint8) - assertSerialize(t, Uint8(val), `$1`, val) + assertSerialize(t, Uint8(val), `$1::smallint`, val) } func TestUint16(t *testing.T) { val := uint16(math.MaxUint16) - assertSerialize(t, Uint16(val), `$1`, val) + assertSerialize(t, Uint16(val), `$1::integer`, val) } func TestUint32(t *testing.T) { val := uint32(math.MaxUint32) - assertSerialize(t, Uint32(val), `$1`, val) + assertSerialize(t, Uint32(val), `$1::bigint`, val) } func TestUint64(t *testing.T) { val := uint64(math.MaxUint64) - assertSerialize(t, Uint64(val), `$1`, val) + assertSerialize(t, Uint64(val), `$1::bigint`, val) } func TestFloat(t *testing.T) { diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go index c3af03b..b487f90 100644 --- a/postgres/select_statement_test.go +++ b/postgres/select_statement_test.go @@ -23,7 +23,7 @@ func TestSelectLiterals(t *testing.T) { assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` SELECT $1, $2, - $3 + $3::boolean FROM db.table1; `, int64(1), 2.2, false) } @@ -59,7 +59,7 @@ func TestSelectWhere(t *testing.T) { assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE $1; +WHERE $1::boolean; `, true) assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` SELECT table1.col_int AS "table1.col_int" @@ -80,7 +80,7 @@ func TestSelectHaving(t *testing.T) { assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` SELECT table3.col_int AS "table3.col_int" FROM db.table3 -HAVING table1.col_bool = $1; +HAVING table1.col_bool = $1::boolean; `, true) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 82ac82b..a70fffe 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -225,7 +225,7 @@ func TestExpressionOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Integer.IS_NULL().AS("result.is_null"), AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), - AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), + AllTypes.SmallIntPtr.IN(Int8(11), Int8(22)).AS("result.in"), AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), Raw("CURRENT_USER").AS("result.raw"), @@ -233,14 +233,16 @@ func TestExpressionOperators(t *testing.T) { Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"), - AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int16(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) + //fmt.Println(query.Sql()) + testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", - (all_types.small_int_ptr IN ($1, $2)) AS "result.in", + (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", (all_types.small_int_ptr IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types @@ -248,14 +250,14 @@ SELECT all_types.integer IS NULL AS "result.is_null", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", - (all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", + (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types )) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; -`, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) +`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -450,13 +452,13 @@ func TestBoolOperators(t *testing.T) { testutils.AssertStatementSql(t, query, ` SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", - (all_types.boolean = $1) AS "EQ2", + (all_types.boolean = $1::boolean) AS "EQ2", (all_types.boolean != all_types.boolean_ptr) AS "NEq1", - (all_types.boolean != $2) AS "NEq2", + (all_types.boolean != $2::boolean) AS "NEq2", (all_types.boolean IS DISTINCT FROM all_types.boolean_ptr) AS "distinct1", - (all_types.boolean IS DISTINCT FROM $3) AS "distinct2", + (all_types.boolean IS DISTINCT FROM $3::boolean) AS "distinct2", (all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1", - (all_types.boolean IS NOT DISTINCT FROM $4) AS "NOTDISTINCT2", + (all_types.boolean IS NOT DISTINCT FROM $4::boolean) AS "NOTDISTINCT2", all_types.boolean IS TRUE AS "ISTRUE", all_types.boolean IS NOT TRUE AS "isnottrue", all_types.boolean IS FALSE AS "is_False", @@ -512,23 +514,23 @@ func TestFloatOperators(t *testing.T) { AllTypes.Numeric.GT(Float(124)).AS("gt1"), AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), - TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"), - TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"), - TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Int(2)).AS("sub1"), - TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int(2)).AS("sub2"), - TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int(2)).AS("mul1"), - TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int(2)).AS("mul2"), - TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int(2)).AS("div1"), - TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int(2)).AS("div2"), - TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int(2)).AS("mod1"), - TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int(2)).AS("mod2"), - TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int(2)).AS("pow1"), - TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int(2)).AS("pow2"), + TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Uint8(2)).AS("add1"), + TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int8(2)).AS("add2"), + TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Uint16(2)).AS("sub1"), + TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int16(2)).AS("sub2"), + TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int16(2)).AS("mul1"), + TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int32(2)).AS("mul2"), + TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int32(2)).AS("div1"), + TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int8(2)).AS("div2"), + TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int8(2)).AS("mod1"), + TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int8(2)).AS("mod2"), + TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int8(2)).AS("pow1"), + TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int8(2)).AS("pow2"), - TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), - TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), - TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), - TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), + TRUNC(ABSf(AllTypes.Decimal), Int8(2)).AS("abs"), + TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int8(2)).AS("power"), + TRUNC(SQRT(AllTypes.Decimal), Int16(2)).AS("sqrt"), + TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int8(2)).AS("cbrt"), CEIL(AllTypes.Real).AS("ceil"), FLOOR(AllTypes.Real).AS("floor"), @@ -536,12 +538,12 @@ func TestFloatOperators(t *testing.T) { ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round2"), SIGN(AllTypes.Real).AS("sign"), - TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), + TRUNC(AllTypes.Decimal, Int32(1)).AS("trunc"), ).LIMIT(2) - queryStr, _ := query.Sql() + //fmt.Println(query.Sql()) - require.Equal(t, queryStr, ` + testutils.AssertStatementSql(t, query, ` SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.decimal = $1) AS "eq2", (all_types.real = $2) AS "eq3", @@ -555,28 +557,28 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.numeric < $8) AS "lt2", (all_types.numeric > $9) AS "gt1", (all_types.numeric > $10) AS "gt2", - TRUNC((all_types.decimal + all_types.decimal), $11) AS "add1", - TRUNC((all_types.decimal + $12), $13) AS "add2", - TRUNC((all_types.decimal - all_types.decimal_ptr), $14) AS "sub1", - TRUNC((all_types.decimal - $15), $16) AS "sub2", - TRUNC((all_types.decimal * all_types.decimal_ptr), $17) AS "mul1", - TRUNC((all_types.decimal * $18), $19) AS "mul2", - TRUNC((all_types.decimal / all_types.decimal_ptr), $20) AS "div1", - TRUNC((all_types.decimal / $21), $22) AS "div2", - TRUNC((all_types.decimal % all_types.decimal_ptr), $23) AS "mod1", - TRUNC((all_types.decimal % $24), $25) AS "mod2", - TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26) AS "pow1", - TRUNC(POW(all_types.decimal, $27), $28) AS "pow2", - TRUNC(ABS(all_types.decimal), $29) AS "abs", - TRUNC(POWER(all_types.decimal, $30), $31) AS "power", - TRUNC(SQRT(all_types.decimal), $32) AS "sqrt", - TRUNC(CBRT(all_types.decimal)::decimal, $33) AS "cbrt", + TRUNC((all_types.decimal + all_types.decimal), $11::smallint) AS "add1", + TRUNC((all_types.decimal + $12), $13::smallint) AS "add2", + TRUNC((all_types.decimal - all_types.decimal_ptr), $14::integer) AS "sub1", + TRUNC((all_types.decimal - $15), $16::smallint) AS "sub2", + TRUNC((all_types.decimal * all_types.decimal_ptr), $17::smallint) AS "mul1", + TRUNC((all_types.decimal * $18), $19::integer) AS "mul2", + TRUNC((all_types.decimal / all_types.decimal_ptr), $20::integer) AS "div1", + TRUNC((all_types.decimal / $21), $22::smallint) AS "div2", + TRUNC((all_types.decimal % all_types.decimal_ptr), $23::smallint) AS "mod1", + TRUNC((all_types.decimal % $24), $25::smallint) AS "mod2", + TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26::smallint) AS "pow1", + TRUNC(POW(all_types.decimal, $27), $28::smallint) AS "pow2", + TRUNC(ABS(all_types.decimal), $29::smallint) AS "abs", + TRUNC(POWER(all_types.decimal, $30), $31::smallint) AS "power", + TRUNC(SQRT(all_types.decimal), $32::smallint) AS "sqrt", + TRUNC(CBRT(all_types.decimal)::decimal, $33::smallint) AS "cbrt", CEIL(all_types.real) AS "ceil", FLOOR(all_types.real) AS "floor", ROUND(all_types.decimal) AS "round1", ROUND(all_types.decimal, all_types.integer) AS "round2", SIGN(all_types.real) AS "sign", - TRUNC(all_types.decimal, $34) AS "trunc" + TRUNC(all_types.decimal, $34::integer) AS "trunc" FROM test_sample.all_types LIMIT $35; `) @@ -602,46 +604,46 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallIntPtr, AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), - AllTypes.BigInt.EQ(Int(12)).AS("eq2"), + AllTypes.BigInt.EQ(Int64(12)).AS("eq2"), AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), - AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), + AllTypes.BigInt.NOT_EQ(Int64(12)).AS("neq2"), AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), - AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), + AllTypes.BigInt.IS_DISTINCT_FROM(Int32(12)).AS("distinct2"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), - AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int32(12)).AS("not distinct2"), AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), - AllTypes.BigInt.LT(Int(65)).AS("lt2"), + AllTypes.BigInt.LT(Uint8(65)).AS("lt2"), AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), - AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), + AllTypes.BigInt.LT_EQ(Uint16(65)).AS("lte2"), AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), - AllTypes.BigInt.GT(Int(65)).AS("gt2"), + AllTypes.BigInt.GT(Uint32(65)).AS("gt2"), AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), - AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + AllTypes.BigInt.GT_EQ(Uint64(65)).AS("gte2"), AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), AllTypes.BigInt.ADD(Int(11)).AS("add2"), AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), - AllTypes.BigInt.SUB(Int(11)).AS("sub2"), + AllTypes.BigInt.SUB(Int8(11)).AS("sub2"), AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), - AllTypes.BigInt.MUL(Int(11)).AS("mul2"), + AllTypes.BigInt.MUL(Int16(11)).AS("mul2"), AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), - AllTypes.BigInt.DIV(Int(11)).AS("div2"), + AllTypes.BigInt.DIV(Int32(11)).AS("div2"), AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), - AllTypes.BigInt.MOD(Int(11)).AS("mod2"), + AllTypes.BigInt.MOD(Int64(11)).AS("mod2"), - AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), - AllTypes.SmallInt.POW(Int(6)).AS("pow2"), + AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int8(3))).AS("pow1"), + AllTypes.SmallInt.POW(Int8(6)).AS("pow2"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), @@ -655,7 +657,7 @@ func TestIntegerOperators(t *testing.T) { BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int(-11)).AS("bit_not_2"), - AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), + AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), @@ -666,7 +668,7 @@ func TestIntegerOperators(t *testing.T) { CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), ).LIMIT(2) - //fmt.Println(query.Sql()) + // fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, ` SELECT all_types.big_int AS "all_types.big_int", @@ -674,33 +676,33 @@ SELECT all_types.big_int AS "all_types.big_int", all_types.small_int AS "all_types.small_int", all_types.small_int_ptr AS "all_types.small_int_ptr", (all_types.big_int = all_types.big_int) AS "eq1", - (all_types.big_int = $1) AS "eq2", + (all_types.big_int = $1::bigint) AS "eq2", (all_types.big_int != all_types.big_int_ptr) AS "neq1", - (all_types.big_int != $2) AS "neq2", + (all_types.big_int != $2::bigint) AS "neq2", (all_types.big_int IS DISTINCT FROM all_types.big_int) AS "distinct1", - (all_types.big_int IS DISTINCT FROM $3) AS "distinct2", + (all_types.big_int IS DISTINCT FROM $3::integer) AS "distinct2", (all_types.big_int IS NOT DISTINCT FROM all_types.big_int) AS "not distinct1", - (all_types.big_int IS NOT DISTINCT FROM $4) AS "not distinct2", + (all_types.big_int IS NOT DISTINCT FROM $4::integer) AS "not distinct2", (all_types.big_int < all_types.big_int_ptr) AS "lt1", - (all_types.big_int < $5) AS "lt2", + (all_types.big_int < $5::smallint) AS "lt2", (all_types.big_int <= all_types.big_int_ptr) AS "lte1", - (all_types.big_int <= $6) AS "lte2", + (all_types.big_int <= $6::integer) AS "lte2", (all_types.big_int > all_types.big_int_ptr) AS "gt1", - (all_types.big_int > $7) AS "gt2", + (all_types.big_int > $7::bigint) AS "gt2", (all_types.big_int >= all_types.big_int_ptr) AS "gte1", - (all_types.big_int >= $8) AS "gte2", + (all_types.big_int >= $8::bigint) AS "gte2", (all_types.big_int + all_types.big_int) AS "add1", (all_types.big_int + $9) AS "add2", (all_types.big_int - all_types.big_int) AS "sub1", - (all_types.big_int - $10) AS "sub2", + (all_types.big_int - $10::smallint) AS "sub2", (all_types.big_int * all_types.big_int) AS "mul1", - (all_types.big_int * $11) AS "mul2", + (all_types.big_int * $11::smallint) AS "mul2", (all_types.big_int / all_types.big_int) AS "div1", - (all_types.big_int / $12) AS "div2", + (all_types.big_int / $12::integer) AS "div2", (all_types.big_int % all_types.big_int) AS "mod1", - (all_types.big_int % $13) AS "mod2", - POW(all_types.small_int, (all_types.small_int / $14)) AS "pow1", - POW(all_types.small_int, $15) AS "pow2", + (all_types.big_int % $13::bigint) AS "mod2", + POW(all_types.small_int, (all_types.small_int / $14::smallint)) AS "pow1", + POW(all_types.small_int, $15::smallint) AS "pow2", (all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and2", (all_types.small_int | all_types.small_int) AS "bit or 1", @@ -709,7 +711,7 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.small_int # $17) AS "bit xor 2", (~ ($18 * all_types.small_int)) AS "bit_not_1", (~ -11) AS "bit_not_2", - (all_types.small_int << (all_types.small_int / $19)) AS "bit shift left 1", + (all_types.small_int << (all_types.small_int / $19::smallint)) AS "bit shift left 1", (all_types.small_int << $20) AS "bit shift left 2", (all_types.small_int >> (all_types.small_int / $21)) AS "bit shift right 1", (all_types.small_int >> $22) AS "bit shift right 2", diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 988499a..90bd8f9 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1943,7 +1943,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer -WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); +WHERE ($1::boolean AND (customer.customer_id = $2)) AND (customer.activebool = $3::boolean); `, true, int64(1), true) dest := []model.Customer{} @@ -2403,3 +2403,24 @@ func TestRecursionScanNx1(t *testing.T) { `) }) } + +// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, +// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. +// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible. +func TestLiteralTypeDeduction(t *testing.T) { + stmt := SELECT( + SUM( + CASE().WHEN(Staff.Active.IS_TRUE()). + THEN(Int8(6)). // if Int8 and Int32 are replaced with Int, + ELSE(Int32(-1)), // execution of this statement will return an error + ).AS("num_passed"), + ).FROM(Staff) + + testutils.AssertStatementSql(t, stmt, ` +SELECT SUM((CASE WHEN staff.active IS TRUE THEN $1::smallint ELSE $2::integer END)) AS "num_passed" +FROM dvds.staff; +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 87ba49b..476333c 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -266,7 +266,7 @@ func TestUpdateWithModelData(t *testing.T) { expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) -WHERE link.id = 201; +WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) @@ -293,7 +293,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.id = 201; +WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201))