From 01305a138f318c16d0e860675a3cbfb46a647c11 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 26 Dec 2021 17:15:43 +0100 Subject: [PATCH] Add automatic type cast for integer literals In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. Int8, Int16, Int32.... functions now will add automatic type cast over a placeholder, so type deduction is always possible. --- postgres/clause_test.go | 2 +- postgres/dialect_test.go | 2 +- postgres/insert_statement_test.go | 4 +- postgres/interval_expression_test.go | 2 +- postgres/literal.go | 37 +++++-- postgres/literal_test.go | 18 ++-- postgres/select_statement_test.go | 6 +- tests/postgres/alltypes_test.go | 154 ++++++++++++++------------- tests/postgres/select_test.go | 23 +++- tests/postgres/update_test.go | 4 +- 10 files changed, 146 insertions(+), 106 deletions(-) diff --git a/postgres/clause_test.go b/postgres/clause_test.go index 5602505..28be315 100644 --- a/postgres/clause_test.go +++ b/postgres/clause_test.go @@ -29,7 +29,7 @@ ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) ) assertClauseSerialize(t, onConflict, ` ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE - SET col_bool = $1, + SET col_bool = $1::boolean, col_int = $2 WHERE table2.col_float > $3`) } diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index d98d8f3..45ed739 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -33,7 +33,7 @@ func TestExists(t *testing.T) { ).EQ(Bool(true)), `((EXISTS ( SELECT $1 -)) = $2)`, int64(1), true) +)) = $2::boolean)`, int64(1), true) assertProjectionSerialize(t, EXISTS( SELECT(Int(1)), diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index ad687b5..3ec333e 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -165,7 +165,7 @@ VALUES ('one', 'two'), ('1', '2'), ('theta', 'beta') ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE - SET col_bool = TRUE, + SET col_bool = TRUE::boolean, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 @@ -191,7 +191,7 @@ INSERT INTO db.table1 (col1, col_bool) VALUES ('one', 'two'), ('1', '2') ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE - SET col_bool = FALSE, + SET col_bool = FALSE::boolean, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 diff --git a/postgres/interval_expression_test.go b/postgres/interval_expression_test.go index f2fa9fe..8d6e647 100644 --- a/postgres/interval_expression_test.go +++ b/postgres/interval_expression_test.go @@ -67,7 +67,7 @@ func TestIntervalExpressionMethods(t *testing.T) { assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)), - "((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false) + "((table1.col_interval = INTERVAL '11 MINUTE') = $1::boolean)", false) assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)") assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)") assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)") diff --git a/postgres/literal.go b/postgres/literal.go index 8ee3235..524b251 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -6,35 +6,52 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -// Bool creates new bool literal expression -var Bool = jet.Bool +func Bool(value bool) BoolExpression { + return CAST(jet.Bool(value)).AS_BOOL() +} // Int is constructor for 64 bit signed integer expressions literals. var Int = jet.Int // Int8 is constructor for 8 bit signed integer expressions literals. -var Int8 = jet.Int8 +func Int8(value int8) IntegerExpression { + return CAST(jet.Int8(value)).AS_SMALLINT() +} // Int16 is constructor for 16 bit signed integer expressions literals. -var Int16 = jet.Int16 +func Int16(value int16) IntegerExpression { + return CAST(jet.Int16(value)).AS_SMALLINT() +} // Int32 is constructor for 32 bit signed integer expressions literals. -var Int32 = jet.Int32 +func Int32(value int32) IntegerExpression { + return CAST(jet.Int32(value)).AS_INTEGER() +} // Int64 is constructor for 64 bit signed integer expressions literals. -var Int64 = jet.Int +func Int64(value int64) IntegerExpression { + return CAST(jet.Int(value)).AS_BIGINT() +} // Uint8 is constructor for 8 bit unsigned integer expressions literals. -var Uint8 = jet.Uint8 +func Uint8(value uint8) IntegerExpression { + return CAST(jet.Uint8(value)).AS_SMALLINT() +} // Uint16 is constructor for 16 bit unsigned integer expressions literals. -var Uint16 = jet.Uint16 +func Uint16(value uint16) IntegerExpression { + return CAST(jet.Uint16(value)).AS_INTEGER() +} // Uint32 is constructor for 32 bit unsigned integer expressions literals. -var Uint32 = jet.Uint32 +func Uint32(value uint32) IntegerExpression { + return CAST(jet.Uint32(value)).AS_BIGINT() +} // Uint64 is constructor for 64 bit unsigned integer expressions literals. -var Uint64 = jet.Uint64 +func Uint64(value uint64) IntegerExpression { + return CAST(jet.Uint64(value)).AS_BIGINT() +} // Float creates new float literal expression var Float = jet.Float diff --git a/postgres/literal_test.go b/postgres/literal_test.go index 52a15a0..f95e486 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -7,7 +7,7 @@ import ( ) func TestBool(t *testing.T) { - assertSerialize(t, Bool(false), `$1`, false) + assertSerialize(t, Bool(false), `$1::boolean`, false) } func TestInt(t *testing.T) { @@ -16,42 +16,42 @@ func TestInt(t *testing.T) { func TestInt8(t *testing.T) { val := int8(math.MinInt8) - assertSerialize(t, Int8(val), `$1`, val) + assertSerialize(t, Int8(val), `$1::smallint`, val) } func TestInt16(t *testing.T) { val := int16(math.MinInt16) - assertSerialize(t, Int16(val), `$1`, val) + assertSerialize(t, Int16(val), `$1::smallint`, val) } func TestInt32(t *testing.T) { val := int32(math.MinInt32) - assertSerialize(t, Int32(val), `$1`, val) + assertSerialize(t, Int32(val), `$1::integer`, val) } func TestInt64(t *testing.T) { val := int64(math.MinInt64) - assertSerialize(t, Int64(val), `$1`, val) + assertSerialize(t, Int64(val), `$1::bigint`, val) } func TestUint8(t *testing.T) { val := uint8(math.MaxUint8) - assertSerialize(t, Uint8(val), `$1`, val) + assertSerialize(t, Uint8(val), `$1::smallint`, val) } func TestUint16(t *testing.T) { val := uint16(math.MaxUint16) - assertSerialize(t, Uint16(val), `$1`, val) + assertSerialize(t, Uint16(val), `$1::integer`, val) } func TestUint32(t *testing.T) { val := uint32(math.MaxUint32) - assertSerialize(t, Uint32(val), `$1`, val) + assertSerialize(t, Uint32(val), `$1::bigint`, val) } func TestUint64(t *testing.T) { val := uint64(math.MaxUint64) - assertSerialize(t, Uint64(val), `$1`, val) + assertSerialize(t, Uint64(val), `$1::bigint`, val) } func TestFloat(t *testing.T) { diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go index c3af03b..b487f90 100644 --- a/postgres/select_statement_test.go +++ b/postgres/select_statement_test.go @@ -23,7 +23,7 @@ func TestSelectLiterals(t *testing.T) { assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` SELECT $1, $2, - $3 + $3::boolean FROM db.table1; `, int64(1), 2.2, false) } @@ -59,7 +59,7 @@ func TestSelectWhere(t *testing.T) { assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE $1; +WHERE $1::boolean; `, true) assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` SELECT table1.col_int AS "table1.col_int" @@ -80,7 +80,7 @@ func TestSelectHaving(t *testing.T) { assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` SELECT table3.col_int AS "table3.col_int" FROM db.table3 -HAVING table1.col_bool = $1; +HAVING table1.col_bool = $1::boolean; `, true) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 82ac82b..a70fffe 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -225,7 +225,7 @@ func TestExpressionOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Integer.IS_NULL().AS("result.is_null"), AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), - AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), + AllTypes.SmallIntPtr.IN(Int8(11), Int8(22)).AS("result.in"), AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), Raw("CURRENT_USER").AS("result.raw"), @@ -233,14 +233,16 @@ func TestExpressionOperators(t *testing.T) { Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"), - AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int16(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) + //fmt.Println(query.Sql()) + testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", - (all_types.small_int_ptr IN ($1, $2)) AS "result.in", + (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", (all_types.small_int_ptr IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types @@ -248,14 +250,14 @@ SELECT all_types.integer IS NULL AS "result.is_null", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", - (all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", + (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types )) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; -`, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) +`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -450,13 +452,13 @@ func TestBoolOperators(t *testing.T) { testutils.AssertStatementSql(t, query, ` SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", - (all_types.boolean = $1) AS "EQ2", + (all_types.boolean = $1::boolean) AS "EQ2", (all_types.boolean != all_types.boolean_ptr) AS "NEq1", - (all_types.boolean != $2) AS "NEq2", + (all_types.boolean != $2::boolean) AS "NEq2", (all_types.boolean IS DISTINCT FROM all_types.boolean_ptr) AS "distinct1", - (all_types.boolean IS DISTINCT FROM $3) AS "distinct2", + (all_types.boolean IS DISTINCT FROM $3::boolean) AS "distinct2", (all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1", - (all_types.boolean IS NOT DISTINCT FROM $4) AS "NOTDISTINCT2", + (all_types.boolean IS NOT DISTINCT FROM $4::boolean) AS "NOTDISTINCT2", all_types.boolean IS TRUE AS "ISTRUE", all_types.boolean IS NOT TRUE AS "isnottrue", all_types.boolean IS FALSE AS "is_False", @@ -512,23 +514,23 @@ func TestFloatOperators(t *testing.T) { AllTypes.Numeric.GT(Float(124)).AS("gt1"), AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), - TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"), - TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"), - TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Int(2)).AS("sub1"), - TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int(2)).AS("sub2"), - TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int(2)).AS("mul1"), - TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int(2)).AS("mul2"), - TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int(2)).AS("div1"), - TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int(2)).AS("div2"), - TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int(2)).AS("mod1"), - TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int(2)).AS("mod2"), - TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int(2)).AS("pow1"), - TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int(2)).AS("pow2"), + TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Uint8(2)).AS("add1"), + TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int8(2)).AS("add2"), + TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Uint16(2)).AS("sub1"), + TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int16(2)).AS("sub2"), + TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int16(2)).AS("mul1"), + TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int32(2)).AS("mul2"), + TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int32(2)).AS("div1"), + TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int8(2)).AS("div2"), + TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int8(2)).AS("mod1"), + TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int8(2)).AS("mod2"), + TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int8(2)).AS("pow1"), + TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int8(2)).AS("pow2"), - TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), - TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), - TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), - TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), + TRUNC(ABSf(AllTypes.Decimal), Int8(2)).AS("abs"), + TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int8(2)).AS("power"), + TRUNC(SQRT(AllTypes.Decimal), Int16(2)).AS("sqrt"), + TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int8(2)).AS("cbrt"), CEIL(AllTypes.Real).AS("ceil"), FLOOR(AllTypes.Real).AS("floor"), @@ -536,12 +538,12 @@ func TestFloatOperators(t *testing.T) { ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round2"), SIGN(AllTypes.Real).AS("sign"), - TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), + TRUNC(AllTypes.Decimal, Int32(1)).AS("trunc"), ).LIMIT(2) - queryStr, _ := query.Sql() + //fmt.Println(query.Sql()) - require.Equal(t, queryStr, ` + testutils.AssertStatementSql(t, query, ` SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.decimal = $1) AS "eq2", (all_types.real = $2) AS "eq3", @@ -555,28 +557,28 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.numeric < $8) AS "lt2", (all_types.numeric > $9) AS "gt1", (all_types.numeric > $10) AS "gt2", - TRUNC((all_types.decimal + all_types.decimal), $11) AS "add1", - TRUNC((all_types.decimal + $12), $13) AS "add2", - TRUNC((all_types.decimal - all_types.decimal_ptr), $14) AS "sub1", - TRUNC((all_types.decimal - $15), $16) AS "sub2", - TRUNC((all_types.decimal * all_types.decimal_ptr), $17) AS "mul1", - TRUNC((all_types.decimal * $18), $19) AS "mul2", - TRUNC((all_types.decimal / all_types.decimal_ptr), $20) AS "div1", - TRUNC((all_types.decimal / $21), $22) AS "div2", - TRUNC((all_types.decimal % all_types.decimal_ptr), $23) AS "mod1", - TRUNC((all_types.decimal % $24), $25) AS "mod2", - TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26) AS "pow1", - TRUNC(POW(all_types.decimal, $27), $28) AS "pow2", - TRUNC(ABS(all_types.decimal), $29) AS "abs", - TRUNC(POWER(all_types.decimal, $30), $31) AS "power", - TRUNC(SQRT(all_types.decimal), $32) AS "sqrt", - TRUNC(CBRT(all_types.decimal)::decimal, $33) AS "cbrt", + TRUNC((all_types.decimal + all_types.decimal), $11::smallint) AS "add1", + TRUNC((all_types.decimal + $12), $13::smallint) AS "add2", + TRUNC((all_types.decimal - all_types.decimal_ptr), $14::integer) AS "sub1", + TRUNC((all_types.decimal - $15), $16::smallint) AS "sub2", + TRUNC((all_types.decimal * all_types.decimal_ptr), $17::smallint) AS "mul1", + TRUNC((all_types.decimal * $18), $19::integer) AS "mul2", + TRUNC((all_types.decimal / all_types.decimal_ptr), $20::integer) AS "div1", + TRUNC((all_types.decimal / $21), $22::smallint) AS "div2", + TRUNC((all_types.decimal % all_types.decimal_ptr), $23::smallint) AS "mod1", + TRUNC((all_types.decimal % $24), $25::smallint) AS "mod2", + TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26::smallint) AS "pow1", + TRUNC(POW(all_types.decimal, $27), $28::smallint) AS "pow2", + TRUNC(ABS(all_types.decimal), $29::smallint) AS "abs", + TRUNC(POWER(all_types.decimal, $30), $31::smallint) AS "power", + TRUNC(SQRT(all_types.decimal), $32::smallint) AS "sqrt", + TRUNC(CBRT(all_types.decimal)::decimal, $33::smallint) AS "cbrt", CEIL(all_types.real) AS "ceil", FLOOR(all_types.real) AS "floor", ROUND(all_types.decimal) AS "round1", ROUND(all_types.decimal, all_types.integer) AS "round2", SIGN(all_types.real) AS "sign", - TRUNC(all_types.decimal, $34) AS "trunc" + TRUNC(all_types.decimal, $34::integer) AS "trunc" FROM test_sample.all_types LIMIT $35; `) @@ -602,46 +604,46 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallIntPtr, AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), - AllTypes.BigInt.EQ(Int(12)).AS("eq2"), + AllTypes.BigInt.EQ(Int64(12)).AS("eq2"), AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), - AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), + AllTypes.BigInt.NOT_EQ(Int64(12)).AS("neq2"), AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), - AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), + AllTypes.BigInt.IS_DISTINCT_FROM(Int32(12)).AS("distinct2"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), - AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int32(12)).AS("not distinct2"), AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), - AllTypes.BigInt.LT(Int(65)).AS("lt2"), + AllTypes.BigInt.LT(Uint8(65)).AS("lt2"), AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), - AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), + AllTypes.BigInt.LT_EQ(Uint16(65)).AS("lte2"), AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), - AllTypes.BigInt.GT(Int(65)).AS("gt2"), + AllTypes.BigInt.GT(Uint32(65)).AS("gt2"), AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), - AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + AllTypes.BigInt.GT_EQ(Uint64(65)).AS("gte2"), AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), AllTypes.BigInt.ADD(Int(11)).AS("add2"), AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), - AllTypes.BigInt.SUB(Int(11)).AS("sub2"), + AllTypes.BigInt.SUB(Int8(11)).AS("sub2"), AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), - AllTypes.BigInt.MUL(Int(11)).AS("mul2"), + AllTypes.BigInt.MUL(Int16(11)).AS("mul2"), AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), - AllTypes.BigInt.DIV(Int(11)).AS("div2"), + AllTypes.BigInt.DIV(Int32(11)).AS("div2"), AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), - AllTypes.BigInt.MOD(Int(11)).AS("mod2"), + AllTypes.BigInt.MOD(Int64(11)).AS("mod2"), - AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), - AllTypes.SmallInt.POW(Int(6)).AS("pow2"), + AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int8(3))).AS("pow1"), + AllTypes.SmallInt.POW(Int8(6)).AS("pow2"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), @@ -655,7 +657,7 @@ func TestIntegerOperators(t *testing.T) { BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int(-11)).AS("bit_not_2"), - AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), + AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), @@ -666,7 +668,7 @@ func TestIntegerOperators(t *testing.T) { CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), ).LIMIT(2) - //fmt.Println(query.Sql()) + // fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, ` SELECT all_types.big_int AS "all_types.big_int", @@ -674,33 +676,33 @@ SELECT all_types.big_int AS "all_types.big_int", all_types.small_int AS "all_types.small_int", all_types.small_int_ptr AS "all_types.small_int_ptr", (all_types.big_int = all_types.big_int) AS "eq1", - (all_types.big_int = $1) AS "eq2", + (all_types.big_int = $1::bigint) AS "eq2", (all_types.big_int != all_types.big_int_ptr) AS "neq1", - (all_types.big_int != $2) AS "neq2", + (all_types.big_int != $2::bigint) AS "neq2", (all_types.big_int IS DISTINCT FROM all_types.big_int) AS "distinct1", - (all_types.big_int IS DISTINCT FROM $3) AS "distinct2", + (all_types.big_int IS DISTINCT FROM $3::integer) AS "distinct2", (all_types.big_int IS NOT DISTINCT FROM all_types.big_int) AS "not distinct1", - (all_types.big_int IS NOT DISTINCT FROM $4) AS "not distinct2", + (all_types.big_int IS NOT DISTINCT FROM $4::integer) AS "not distinct2", (all_types.big_int < all_types.big_int_ptr) AS "lt1", - (all_types.big_int < $5) AS "lt2", + (all_types.big_int < $5::smallint) AS "lt2", (all_types.big_int <= all_types.big_int_ptr) AS "lte1", - (all_types.big_int <= $6) AS "lte2", + (all_types.big_int <= $6::integer) AS "lte2", (all_types.big_int > all_types.big_int_ptr) AS "gt1", - (all_types.big_int > $7) AS "gt2", + (all_types.big_int > $7::bigint) AS "gt2", (all_types.big_int >= all_types.big_int_ptr) AS "gte1", - (all_types.big_int >= $8) AS "gte2", + (all_types.big_int >= $8::bigint) AS "gte2", (all_types.big_int + all_types.big_int) AS "add1", (all_types.big_int + $9) AS "add2", (all_types.big_int - all_types.big_int) AS "sub1", - (all_types.big_int - $10) AS "sub2", + (all_types.big_int - $10::smallint) AS "sub2", (all_types.big_int * all_types.big_int) AS "mul1", - (all_types.big_int * $11) AS "mul2", + (all_types.big_int * $11::smallint) AS "mul2", (all_types.big_int / all_types.big_int) AS "div1", - (all_types.big_int / $12) AS "div2", + (all_types.big_int / $12::integer) AS "div2", (all_types.big_int % all_types.big_int) AS "mod1", - (all_types.big_int % $13) AS "mod2", - POW(all_types.small_int, (all_types.small_int / $14)) AS "pow1", - POW(all_types.small_int, $15) AS "pow2", + (all_types.big_int % $13::bigint) AS "mod2", + POW(all_types.small_int, (all_types.small_int / $14::smallint)) AS "pow1", + POW(all_types.small_int, $15::smallint) AS "pow2", (all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and2", (all_types.small_int | all_types.small_int) AS "bit or 1", @@ -709,7 +711,7 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.small_int # $17) AS "bit xor 2", (~ ($18 * all_types.small_int)) AS "bit_not_1", (~ -11) AS "bit_not_2", - (all_types.small_int << (all_types.small_int / $19)) AS "bit shift left 1", + (all_types.small_int << (all_types.small_int / $19::smallint)) AS "bit shift left 1", (all_types.small_int << $20) AS "bit shift left 2", (all_types.small_int >> (all_types.small_int / $21)) AS "bit shift right 1", (all_types.small_int >> $22) AS "bit shift right 2", diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 988499a..90bd8f9 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1943,7 +1943,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer -WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); +WHERE ($1::boolean AND (customer.customer_id = $2)) AND (customer.activebool = $3::boolean); `, true, int64(1), true) dest := []model.Customer{} @@ -2403,3 +2403,24 @@ func TestRecursionScanNx1(t *testing.T) { `) }) } + +// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, +// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. +// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible. +func TestLiteralTypeDeduction(t *testing.T) { + stmt := SELECT( + SUM( + CASE().WHEN(Staff.Active.IS_TRUE()). + THEN(Int8(6)). // if Int8 and Int32 are replaced with Int, + ELSE(Int32(-1)), // execution of this statement will return an error + ).AS("num_passed"), + ).FROM(Staff) + + testutils.AssertStatementSql(t, stmt, ` +SELECT SUM((CASE WHEN staff.active IS TRUE THEN $1::smallint ELSE $2::integer END)) AS "num_passed" +FROM dvds.staff; +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 87ba49b..476333c 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -266,7 +266,7 @@ func TestUpdateWithModelData(t *testing.T) { expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) -WHERE link.id = 201; +WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) @@ -293,7 +293,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.id = 201; +WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201))