Add automatic type cast for integer literals

In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions,
postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error.
Int8, Int16, Int32.... functions now will add automatic type cast over a placeholder, so type deduction is always possible.
This commit is contained in:
go-jet 2021-12-26 17:15:43 +01:00
parent 47545ce571
commit 01305a138f
10 changed files with 146 additions and 106 deletions

View file

@ -29,7 +29,7 @@ ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
) )
assertClauseSerialize(t, onConflict, ` assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
SET col_bool = $1, SET col_bool = $1::boolean,
col_int = $2 col_int = $2
WHERE table2.col_float > $3`) WHERE table2.col_float > $3`)
} }

View file

@ -33,7 +33,7 @@ func TestExists(t *testing.T) {
).EQ(Bool(true)), ).EQ(Bool(true)),
`((EXISTS ( `((EXISTS (
SELECT $1 SELECT $1
)) = $2)`, int64(1), true) )) = $2::boolean)`, int64(1), true)
assertProjectionSerialize(t, EXISTS( assertProjectionSerialize(t, EXISTS(
SELECT(Int(1)), SELECT(Int(1)),

View file

@ -165,7 +165,7 @@ VALUES ('one', 'two'),
('1', '2'), ('1', '2'),
('theta', 'beta') ('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = TRUE, SET col_bool = TRUE::boolean,
col_int = 1, col_int = 1,
(col1, col_bool) = ROW(2, 'two') (col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2 WHERE table1.col1 > 2
@ -191,7 +191,7 @@ INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'), VALUES ('one', 'two'),
('1', '2') ('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = FALSE, SET col_bool = FALSE::boolean,
col_int = 1, col_int = 1,
(col1, col_bool) = ROW(2, 'two') (col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2 WHERE table1.col1 > 2

View file

@ -67,7 +67,7 @@ func TestIntervalExpressionMethods(t *testing.T) {
assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')") assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')")
assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')") assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')")
assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)), assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)),
"((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false) "((table1.col_interval = INTERVAL '11 MINUTE') = $1::boolean)", false)
assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)") assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)")
assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)") assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)")
assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)") assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)")

View file

@ -6,35 +6,52 @@ import (
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
// Bool creates new bool literal expression func Bool(value bool) BoolExpression {
var Bool = jet.Bool return CAST(jet.Bool(value)).AS_BOOL()
}
// Int is constructor for 64 bit signed integer expressions literals. // Int is constructor for 64 bit signed integer expressions literals.
var Int = jet.Int var Int = jet.Int
// Int8 is constructor for 8 bit signed integer expressions literals. // Int8 is constructor for 8 bit signed integer expressions literals.
var Int8 = jet.Int8 func Int8(value int8) IntegerExpression {
return CAST(jet.Int8(value)).AS_SMALLINT()
}
// Int16 is constructor for 16 bit signed integer expressions literals. // Int16 is constructor for 16 bit signed integer expressions literals.
var Int16 = jet.Int16 func Int16(value int16) IntegerExpression {
return CAST(jet.Int16(value)).AS_SMALLINT()
}
// Int32 is constructor for 32 bit signed integer expressions literals. // Int32 is constructor for 32 bit signed integer expressions literals.
var Int32 = jet.Int32 func Int32(value int32) IntegerExpression {
return CAST(jet.Int32(value)).AS_INTEGER()
}
// Int64 is constructor for 64 bit signed integer expressions literals. // Int64 is constructor for 64 bit signed integer expressions literals.
var Int64 = jet.Int func Int64(value int64) IntegerExpression {
return CAST(jet.Int(value)).AS_BIGINT()
}
// Uint8 is constructor for 8 bit unsigned integer expressions literals. // Uint8 is constructor for 8 bit unsigned integer expressions literals.
var Uint8 = jet.Uint8 func Uint8(value uint8) IntegerExpression {
return CAST(jet.Uint8(value)).AS_SMALLINT()
}
// Uint16 is constructor for 16 bit unsigned integer expressions literals. // Uint16 is constructor for 16 bit unsigned integer expressions literals.
var Uint16 = jet.Uint16 func Uint16(value uint16) IntegerExpression {
return CAST(jet.Uint16(value)).AS_INTEGER()
}
// Uint32 is constructor for 32 bit unsigned integer expressions literals. // Uint32 is constructor for 32 bit unsigned integer expressions literals.
var Uint32 = jet.Uint32 func Uint32(value uint32) IntegerExpression {
return CAST(jet.Uint32(value)).AS_BIGINT()
}
// Uint64 is constructor for 64 bit unsigned integer expressions literals. // Uint64 is constructor for 64 bit unsigned integer expressions literals.
var Uint64 = jet.Uint64 func Uint64(value uint64) IntegerExpression {
return CAST(jet.Uint64(value)).AS_BIGINT()
}
// Float creates new float literal expression // Float creates new float literal expression
var Float = jet.Float var Float = jet.Float

View file

@ -7,7 +7,7 @@ import (
) )
func TestBool(t *testing.T) { func TestBool(t *testing.T) {
assertSerialize(t, Bool(false), `$1`, false) assertSerialize(t, Bool(false), `$1::boolean`, false)
} }
func TestInt(t *testing.T) { func TestInt(t *testing.T) {
@ -16,42 +16,42 @@ func TestInt(t *testing.T) {
func TestInt8(t *testing.T) { func TestInt8(t *testing.T) {
val := int8(math.MinInt8) val := int8(math.MinInt8)
assertSerialize(t, Int8(val), `$1`, val) assertSerialize(t, Int8(val), `$1::smallint`, val)
} }
func TestInt16(t *testing.T) { func TestInt16(t *testing.T) {
val := int16(math.MinInt16) val := int16(math.MinInt16)
assertSerialize(t, Int16(val), `$1`, val) assertSerialize(t, Int16(val), `$1::smallint`, val)
} }
func TestInt32(t *testing.T) { func TestInt32(t *testing.T) {
val := int32(math.MinInt32) val := int32(math.MinInt32)
assertSerialize(t, Int32(val), `$1`, val) assertSerialize(t, Int32(val), `$1::integer`, val)
} }
func TestInt64(t *testing.T) { func TestInt64(t *testing.T) {
val := int64(math.MinInt64) val := int64(math.MinInt64)
assertSerialize(t, Int64(val), `$1`, val) assertSerialize(t, Int64(val), `$1::bigint`, val)
} }
func TestUint8(t *testing.T) { func TestUint8(t *testing.T) {
val := uint8(math.MaxUint8) val := uint8(math.MaxUint8)
assertSerialize(t, Uint8(val), `$1`, val) assertSerialize(t, Uint8(val), `$1::smallint`, val)
} }
func TestUint16(t *testing.T) { func TestUint16(t *testing.T) {
val := uint16(math.MaxUint16) val := uint16(math.MaxUint16)
assertSerialize(t, Uint16(val), `$1`, val) assertSerialize(t, Uint16(val), `$1::integer`, val)
} }
func TestUint32(t *testing.T) { func TestUint32(t *testing.T) {
val := uint32(math.MaxUint32) val := uint32(math.MaxUint32)
assertSerialize(t, Uint32(val), `$1`, val) assertSerialize(t, Uint32(val), `$1::bigint`, val)
} }
func TestUint64(t *testing.T) { func TestUint64(t *testing.T) {
val := uint64(math.MaxUint64) val := uint64(math.MaxUint64)
assertSerialize(t, Uint64(val), `$1`, val) assertSerialize(t, Uint64(val), `$1::bigint`, val)
} }
func TestFloat(t *testing.T) { func TestFloat(t *testing.T) {

View file

@ -23,7 +23,7 @@ func TestSelectLiterals(t *testing.T) {
assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), `
SELECT $1, SELECT $1,
$2, $2,
$3 $3::boolean
FROM db.table1; FROM db.table1;
`, int64(1), 2.2, false) `, int64(1), 2.2, false)
} }
@ -59,7 +59,7 @@ func TestSelectWhere(t *testing.T) {
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.col_int AS "table1.col_int" SELECT table1.col_int AS "table1.col_int"
FROM db.table1 FROM db.table1
WHERE $1; WHERE $1::boolean;
`, true) `, true)
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.col_int AS "table1.col_int" SELECT table1.col_int AS "table1.col_int"
@ -80,7 +80,7 @@ func TestSelectHaving(t *testing.T) {
assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.col_int AS "table3.col_int" SELECT table3.col_int AS "table3.col_int"
FROM db.table3 FROM db.table3
HAVING table1.col_bool = $1; HAVING table1.col_bool = $1::boolean;
`, true) `, true)
} }

View file

@ -225,7 +225,7 @@ func TestExpressionOperators(t *testing.T) {
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.Integer.IS_NULL().AS("result.is_null"), AllTypes.Integer.IS_NULL().AS("result.is_null"),
AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"),
AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), AllTypes.SmallIntPtr.IN(Int8(11), Int8(22)).AS("result.in"),
AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"),
Raw("CURRENT_USER").AS("result.raw"), Raw("CURRENT_USER").AS("result.raw"),
@ -233,14 +233,16 @@ func TestExpressionOperators(t *testing.T) {
Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", Raw("#1 + all_types.integer + #2 + #1 + #3 + #4",
RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"), RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"),
AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(Int(11), Int16(22), NULL).AS("result.not_in"),
AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"),
).LIMIT(2) ).LIMIT(2)
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT all_types.integer IS NULL AS "result.is_null", SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN ($1, $2)) AS "result.in", (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in",
(all_types.small_int_ptr IN ( (all_types.small_int_ptr IN (
SELECT all_types.integer AS "all_types.integer" SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types FROM test_sample.all_types
@ -248,14 +250,14 @@ SELECT all_types.integer IS NULL AS "result.is_null",
(CURRENT_USER) AS "result.raw", (CURRENT_USER) AS "result.raw",
($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg",
($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN ( (all_types.small_int_ptr NOT IN (
SELECT all_types.integer AS "all_types.integer" SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types FROM test_sample.all_types
)) AS "result.not_in_select" )) AS "result.not_in_select"
FROM test_sample.all_types FROM test_sample.all_types
LIMIT $11; LIMIT $11;
`, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) `, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2))
var dest []struct { var dest []struct {
common.ExpressionTestResult `alias:"result.*"` common.ExpressionTestResult `alias:"result.*"`
@ -450,13 +452,13 @@ func TestBoolOperators(t *testing.T) {
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1",
(all_types.boolean = $1) AS "EQ2", (all_types.boolean = $1::boolean) AS "EQ2",
(all_types.boolean != all_types.boolean_ptr) AS "NEq1", (all_types.boolean != all_types.boolean_ptr) AS "NEq1",
(all_types.boolean != $2) AS "NEq2", (all_types.boolean != $2::boolean) AS "NEq2",
(all_types.boolean IS DISTINCT FROM all_types.boolean_ptr) AS "distinct1", (all_types.boolean IS DISTINCT FROM all_types.boolean_ptr) AS "distinct1",
(all_types.boolean IS DISTINCT FROM $3) AS "distinct2", (all_types.boolean IS DISTINCT FROM $3::boolean) AS "distinct2",
(all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1", (all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1",
(all_types.boolean IS NOT DISTINCT FROM $4) AS "NOTDISTINCT2", (all_types.boolean IS NOT DISTINCT FROM $4::boolean) AS "NOTDISTINCT2",
all_types.boolean IS TRUE AS "ISTRUE", all_types.boolean IS TRUE AS "ISTRUE",
all_types.boolean IS NOT TRUE AS "isnottrue", all_types.boolean IS NOT TRUE AS "isnottrue",
all_types.boolean IS FALSE AS "is_False", all_types.boolean IS FALSE AS "is_False",
@ -512,23 +514,23 @@ func TestFloatOperators(t *testing.T) {
AllTypes.Numeric.GT(Float(124)).AS("gt1"), AllTypes.Numeric.GT(Float(124)).AS("gt1"),
AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), AllTypes.Numeric.GT(Float(34.56)).AS("gt2"),
TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"), TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Uint8(2)).AS("add1"),
TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"), TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int8(2)).AS("add2"),
TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Int(2)).AS("sub1"), TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Uint16(2)).AS("sub1"),
TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int(2)).AS("sub2"), TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int16(2)).AS("sub2"),
TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int(2)).AS("mul1"), TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int16(2)).AS("mul1"),
TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int(2)).AS("mul2"), TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int32(2)).AS("mul2"),
TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int(2)).AS("div1"), TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int32(2)).AS("div1"),
TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int(2)).AS("div2"), TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int8(2)).AS("div2"),
TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int(2)).AS("mod1"), TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int8(2)).AS("mod1"),
TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int(2)).AS("mod2"), TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int8(2)).AS("mod2"),
TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int(2)).AS("pow1"), TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int8(2)).AS("pow1"),
TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int(2)).AS("pow2"), TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int8(2)).AS("pow2"),
TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), TRUNC(ABSf(AllTypes.Decimal), Int8(2)).AS("abs"),
TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int8(2)).AS("power"),
TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), TRUNC(SQRT(AllTypes.Decimal), Int16(2)).AS("sqrt"),
TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int8(2)).AS("cbrt"),
CEIL(AllTypes.Real).AS("ceil"), CEIL(AllTypes.Real).AS("ceil"),
FLOOR(AllTypes.Real).AS("floor"), FLOOR(AllTypes.Real).AS("floor"),
@ -536,12 +538,12 @@ func TestFloatOperators(t *testing.T) {
ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round2"), ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round2"),
SIGN(AllTypes.Real).AS("sign"), SIGN(AllTypes.Real).AS("sign"),
TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), TRUNC(AllTypes.Decimal, Int32(1)).AS("trunc"),
).LIMIT(2) ).LIMIT(2)
queryStr, _ := query.Sql() //fmt.Println(query.Sql())
require.Equal(t, queryStr, ` testutils.AssertStatementSql(t, query, `
SELECT (all_types.numeric = all_types.numeric) AS "eq1", SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.decimal = $1) AS "eq2", (all_types.decimal = $1) AS "eq2",
(all_types.real = $2) AS "eq3", (all_types.real = $2) AS "eq3",
@ -555,28 +557,28 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.numeric < $8) AS "lt2", (all_types.numeric < $8) AS "lt2",
(all_types.numeric > $9) AS "gt1", (all_types.numeric > $9) AS "gt1",
(all_types.numeric > $10) AS "gt2", (all_types.numeric > $10) AS "gt2",
TRUNC((all_types.decimal + all_types.decimal), $11) AS "add1", TRUNC((all_types.decimal + all_types.decimal), $11::smallint) AS "add1",
TRUNC((all_types.decimal + $12), $13) AS "add2", TRUNC((all_types.decimal + $12), $13::smallint) AS "add2",
TRUNC((all_types.decimal - all_types.decimal_ptr), $14) AS "sub1", TRUNC((all_types.decimal - all_types.decimal_ptr), $14::integer) AS "sub1",
TRUNC((all_types.decimal - $15), $16) AS "sub2", TRUNC((all_types.decimal - $15), $16::smallint) AS "sub2",
TRUNC((all_types.decimal * all_types.decimal_ptr), $17) AS "mul1", TRUNC((all_types.decimal * all_types.decimal_ptr), $17::smallint) AS "mul1",
TRUNC((all_types.decimal * $18), $19) AS "mul2", TRUNC((all_types.decimal * $18), $19::integer) AS "mul2",
TRUNC((all_types.decimal / all_types.decimal_ptr), $20) AS "div1", TRUNC((all_types.decimal / all_types.decimal_ptr), $20::integer) AS "div1",
TRUNC((all_types.decimal / $21), $22) AS "div2", TRUNC((all_types.decimal / $21), $22::smallint) AS "div2",
TRUNC((all_types.decimal % all_types.decimal_ptr), $23) AS "mod1", TRUNC((all_types.decimal % all_types.decimal_ptr), $23::smallint) AS "mod1",
TRUNC((all_types.decimal % $24), $25) AS "mod2", TRUNC((all_types.decimal % $24), $25::smallint) AS "mod2",
TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26) AS "pow1", TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26::smallint) AS "pow1",
TRUNC(POW(all_types.decimal, $27), $28) AS "pow2", TRUNC(POW(all_types.decimal, $27), $28::smallint) AS "pow2",
TRUNC(ABS(all_types.decimal), $29) AS "abs", TRUNC(ABS(all_types.decimal), $29::smallint) AS "abs",
TRUNC(POWER(all_types.decimal, $30), $31) AS "power", TRUNC(POWER(all_types.decimal, $30), $31::smallint) AS "power",
TRUNC(SQRT(all_types.decimal), $32) AS "sqrt", TRUNC(SQRT(all_types.decimal), $32::smallint) AS "sqrt",
TRUNC(CBRT(all_types.decimal)::decimal, $33) AS "cbrt", TRUNC(CBRT(all_types.decimal)::decimal, $33::smallint) AS "cbrt",
CEIL(all_types.real) AS "ceil", CEIL(all_types.real) AS "ceil",
FLOOR(all_types.real) AS "floor", FLOOR(all_types.real) AS "floor",
ROUND(all_types.decimal) AS "round1", ROUND(all_types.decimal) AS "round1",
ROUND(all_types.decimal, all_types.integer) AS "round2", ROUND(all_types.decimal, all_types.integer) AS "round2",
SIGN(all_types.real) AS "sign", SIGN(all_types.real) AS "sign",
TRUNC(all_types.decimal, $34) AS "trunc" TRUNC(all_types.decimal, $34::integer) AS "trunc"
FROM test_sample.all_types FROM test_sample.all_types
LIMIT $35; LIMIT $35;
`) `)
@ -602,46 +604,46 @@ func TestIntegerOperators(t *testing.T) {
AllTypes.SmallIntPtr, AllTypes.SmallIntPtr,
AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"),
AllTypes.BigInt.EQ(Int(12)).AS("eq2"), AllTypes.BigInt.EQ(Int64(12)).AS("eq2"),
AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"),
AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), AllTypes.BigInt.NOT_EQ(Int64(12)).AS("neq2"),
AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"),
AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), AllTypes.BigInt.IS_DISTINCT_FROM(Int32(12)).AS("distinct2"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int32(12)).AS("not distinct2"),
AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"),
AllTypes.BigInt.LT(Int(65)).AS("lt2"), AllTypes.BigInt.LT(Uint8(65)).AS("lt2"),
AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"),
AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), AllTypes.BigInt.LT_EQ(Uint16(65)).AS("lte2"),
AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"),
AllTypes.BigInt.GT(Int(65)).AS("gt2"), AllTypes.BigInt.GT(Uint32(65)).AS("gt2"),
AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"),
AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), AllTypes.BigInt.GT_EQ(Uint64(65)).AS("gte2"),
AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"),
AllTypes.BigInt.ADD(Int(11)).AS("add2"), AllTypes.BigInt.ADD(Int(11)).AS("add2"),
AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"),
AllTypes.BigInt.SUB(Int(11)).AS("sub2"), AllTypes.BigInt.SUB(Int8(11)).AS("sub2"),
AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"),
AllTypes.BigInt.MUL(Int(11)).AS("mul2"), AllTypes.BigInt.MUL(Int16(11)).AS("mul2"),
AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"),
AllTypes.BigInt.DIV(Int(11)).AS("div2"), AllTypes.BigInt.DIV(Int32(11)).AS("div2"),
AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"),
AllTypes.BigInt.MOD(Int(11)).AS("mod2"), AllTypes.BigInt.MOD(Int64(11)).AS("mod2"),
AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int8(3))).AS("pow1"),
AllTypes.SmallInt.POW(Int(6)).AS("pow2"), AllTypes.SmallInt.POW(Int8(6)).AS("pow2"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"),
@ -655,7 +657,7 @@ func TestIntegerOperators(t *testing.T) {
BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"),
BIT_NOT(Int(-11)).AS("bit_not_2"), BIT_NOT(Int(-11)).AS("bit_not_2"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"),
@ -666,7 +668,7 @@ func TestIntegerOperators(t *testing.T) {
CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"),
).LIMIT(2) ).LIMIT(2)
//fmt.Println(query.Sql()) // fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT all_types.big_int AS "all_types.big_int", SELECT all_types.big_int AS "all_types.big_int",
@ -674,33 +676,33 @@ SELECT all_types.big_int AS "all_types.big_int",
all_types.small_int AS "all_types.small_int", all_types.small_int AS "all_types.small_int",
all_types.small_int_ptr AS "all_types.small_int_ptr", all_types.small_int_ptr AS "all_types.small_int_ptr",
(all_types.big_int = all_types.big_int) AS "eq1", (all_types.big_int = all_types.big_int) AS "eq1",
(all_types.big_int = $1) AS "eq2", (all_types.big_int = $1::bigint) AS "eq2",
(all_types.big_int != all_types.big_int_ptr) AS "neq1", (all_types.big_int != all_types.big_int_ptr) AS "neq1",
(all_types.big_int != $2) AS "neq2", (all_types.big_int != $2::bigint) AS "neq2",
(all_types.big_int IS DISTINCT FROM all_types.big_int) AS "distinct1", (all_types.big_int IS DISTINCT FROM all_types.big_int) AS "distinct1",
(all_types.big_int IS DISTINCT FROM $3) AS "distinct2", (all_types.big_int IS DISTINCT FROM $3::integer) AS "distinct2",
(all_types.big_int IS NOT DISTINCT FROM all_types.big_int) AS "not distinct1", (all_types.big_int IS NOT DISTINCT FROM all_types.big_int) AS "not distinct1",
(all_types.big_int IS NOT DISTINCT FROM $4) AS "not distinct2", (all_types.big_int IS NOT DISTINCT FROM $4::integer) AS "not distinct2",
(all_types.big_int < all_types.big_int_ptr) AS "lt1", (all_types.big_int < all_types.big_int_ptr) AS "lt1",
(all_types.big_int < $5) AS "lt2", (all_types.big_int < $5::smallint) AS "lt2",
(all_types.big_int <= all_types.big_int_ptr) AS "lte1", (all_types.big_int <= all_types.big_int_ptr) AS "lte1",
(all_types.big_int <= $6) AS "lte2", (all_types.big_int <= $6::integer) AS "lte2",
(all_types.big_int > all_types.big_int_ptr) AS "gt1", (all_types.big_int > all_types.big_int_ptr) AS "gt1",
(all_types.big_int > $7) AS "gt2", (all_types.big_int > $7::bigint) AS "gt2",
(all_types.big_int >= all_types.big_int_ptr) AS "gte1", (all_types.big_int >= all_types.big_int_ptr) AS "gte1",
(all_types.big_int >= $8) AS "gte2", (all_types.big_int >= $8::bigint) AS "gte2",
(all_types.big_int + all_types.big_int) AS "add1", (all_types.big_int + all_types.big_int) AS "add1",
(all_types.big_int + $9) AS "add2", (all_types.big_int + $9) AS "add2",
(all_types.big_int - all_types.big_int) AS "sub1", (all_types.big_int - all_types.big_int) AS "sub1",
(all_types.big_int - $10) AS "sub2", (all_types.big_int - $10::smallint) AS "sub2",
(all_types.big_int * all_types.big_int) AS "mul1", (all_types.big_int * all_types.big_int) AS "mul1",
(all_types.big_int * $11) AS "mul2", (all_types.big_int * $11::smallint) AS "mul2",
(all_types.big_int / all_types.big_int) AS "div1", (all_types.big_int / all_types.big_int) AS "div1",
(all_types.big_int / $12) AS "div2", (all_types.big_int / $12::integer) AS "div2",
(all_types.big_int % all_types.big_int) AS "mod1", (all_types.big_int % all_types.big_int) AS "mod1",
(all_types.big_int % $13) AS "mod2", (all_types.big_int % $13::bigint) AS "mod2",
POW(all_types.small_int, (all_types.small_int / $14)) AS "pow1", POW(all_types.small_int, (all_types.small_int / $14::smallint)) AS "pow1",
POW(all_types.small_int, $15) AS "pow2", POW(all_types.small_int, $15::smallint) AS "pow2",
(all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and1",
(all_types.small_int & all_types.small_int) AS "bit_and2", (all_types.small_int & all_types.small_int) AS "bit_and2",
(all_types.small_int | all_types.small_int) AS "bit or 1", (all_types.small_int | all_types.small_int) AS "bit or 1",
@ -709,7 +711,7 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.small_int # $17) AS "bit xor 2", (all_types.small_int # $17) AS "bit xor 2",
(~ ($18 * all_types.small_int)) AS "bit_not_1", (~ ($18 * all_types.small_int)) AS "bit_not_1",
(~ -11) AS "bit_not_2", (~ -11) AS "bit_not_2",
(all_types.small_int << (all_types.small_int / $19)) AS "bit shift left 1", (all_types.small_int << (all_types.small_int / $19::smallint)) AS "bit shift left 1",
(all_types.small_int << $20) AS "bit shift left 2", (all_types.small_int << $20) AS "bit shift left 2",
(all_types.small_int >> (all_types.small_int / $21)) AS "bit shift right 1", (all_types.small_int >> (all_types.small_int / $21)) AS "bit shift right 1",
(all_types.small_int >> $22) AS "bit shift right 2", (all_types.small_int >> $22) AS "bit shift right 2",

View file

@ -1943,7 +1943,7 @@ SELECT customer.customer_id AS "customer.customer_id",
customer.last_update AS "customer.last_update", customer.last_update AS "customer.last_update",
customer.active AS "customer.active" customer.active AS "customer.active"
FROM dvds.customer FROM dvds.customer
WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); WHERE ($1::boolean AND (customer.customer_id = $2)) AND (customer.activebool = $3::boolean);
`, true, int64(1), true) `, true, int64(1), true)
dest := []model.Customer{} dest := []model.Customer{}
@ -2403,3 +2403,24 @@ func TestRecursionScanNx1(t *testing.T) {
`) `)
}) })
} }
// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions,
// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error.
// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible.
func TestLiteralTypeDeduction(t *testing.T) {
stmt := SELECT(
SUM(
CASE().WHEN(Staff.Active.IS_TRUE()).
THEN(Int8(6)). // if Int8 and Int32 are replaced with Int,
ELSE(Int32(-1)), // execution of this statement will return an error
).AS("num_passed"),
).FROM(Staff)
testutils.AssertStatementSql(t, stmt, `
SELECT SUM((CASE WHEN staff.active IS TRUE THEN $1::smallint ELSE $2::integer END)) AS "num_passed"
FROM dvds.staff;
`)
err := stmt.Query(db, &struct{}{})
require.NoError(t, err)
}

View file

@ -266,7 +266,7 @@ func TestUpdateWithModelData(t *testing.T) {
expectedSQL := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL)
WHERE link.id = 201; WHERE link.id = 201::integer;
` `
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
@ -293,7 +293,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
var expectedSQL = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.id = 201; WHERE link.id = 201::integer;
` `
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201))