From dca028295d4e7426865db67764e2ea9bdd547240 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Mon, 3 Jun 2019 17:38:47 +0200 Subject: [PATCH] Conditional expression functions. --- sqlbuilder/func_expression.go | 24 ++++++++++++++++++++++++ sqlbuilder/func_expression_test.go | 29 +++++++++++++++-------------- sqlbuilder/keyword.go | 4 ++++ sqlbuilder/literal_expression.go | 18 ++++++++++++++++++ sqlbuilder/operators.go | 2 +- sqlbuilder/operators_test.go | 22 ++++++++++++++++++++++ tests/types_test.go | 5 +++++ 7 files changed, 89 insertions(+), 15 deletions(-) create mode 100644 sqlbuilder/operators_test.go diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index a9816ee..ea1f425 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -503,3 +503,27 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression { func NOW() TimestampzExpression { return newTimestampzFunc("NOW") } + +// --------------- Conditional Expressions Functions -------------// + +func COALESCE(value expression, values ...expression) expression { + var allValues = []expression{value} + allValues = append(allValues, values...) + return newFunc("COALESCE", allValues, nil) +} + +func NULLIF(value1, value2 expression) expression { + return newFunc("NULLIF", []expression{value1, value2}, nil) +} + +func GREATEST(value expression, values ...expression) expression { + var allValues = []expression{value} + allValues = append(allValues, values...) + return newFunc("GREATEST", allValues, nil) +} + +func LEAST(value expression, values ...expression) expression { + var allValues = []expression{value} + allValues = append(allValues, values...) + return newFunc("LEAST", allValues, nil) +} diff --git a/sqlbuilder/func_expression_test.go b/sqlbuilder/func_expression_test.go index 5f7c553..b6276a5 100644 --- a/sqlbuilder/func_expression_test.go +++ b/sqlbuilder/func_expression_test.go @@ -118,23 +118,24 @@ func TestFuncLOG(t *testing.T) { assertExpressionSerialize(t, LOG(Float(11.2222)), "LOG($1)", float64(11.2222)) } -func TestCase1(t *testing.T) { - query := CASE(). - WHEN(table3Col1.EQ(Int(1))).THEN(table3Col1.ADD(Int(1))). - WHEN(table3Col1.EQ(Int(2))).THEN(table3Col1.ADD(Int(2))) - - assertExpressionSerialize(t, query, `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`, - int64(1), int64(1), int64(2), int64(2)) +func TestFuncCOALESCE(t *testing.T) { + assertExpressionSerialize(t, COALESCE(table1ColFloat), "COALESCE(table1.colFloat)") + assertExpressionSerialize(t, COALESCE(Float(11.2222), NULL, String("str")), "COALESCE($1, NULL, $2)", float64(11.2222), "str") } -func TestCase2(t *testing.T) { - query := CASE(table3Col1). - WHEN(Int(1)).THEN(table3Col1.ADD(Int(1))). - WHEN(Int(2)).THEN(table3Col1.ADD(Int(2))). - ELSE(Int(0)) +func TestFuncNULLIF(t *testing.T) { + assertExpressionSerialize(t, NULLIF(table1ColFloat, table2ColInt), "NULLIF(table1.colFloat, table2.colInt)") + assertExpressionSerialize(t, NULLIF(Float(11.2222), NULL), "NULLIF($1, NULL)", float64(11.2222)) +} - assertExpressionSerialize(t, query, `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`, - int64(1), int64(1), int64(2), int64(2), int64(0)) +func TestFuncGREATEST(t *testing.T) { + assertExpressionSerialize(t, GREATEST(table1ColFloat), "GREATEST(table1.colFloat)") + assertExpressionSerialize(t, GREATEST(Float(11.2222), NULL, String("str")), "GREATEST($1, NULL, $2)", float64(11.2222), "str") +} + +func TestFuncLEAST(t *testing.T) { + assertExpressionSerialize(t, LEAST(table1ColFloat), "LEAST(table1.colFloat)") + assertExpressionSerialize(t, LEAST(Float(11.2222), NULL, String("str")), "LEAST($1, NULL, $2)", float64(11.2222), "str") } func TestInterval(t *testing.T) { diff --git a/sqlbuilder/keyword.go b/sqlbuilder/keyword.go index fb3f8b3..5e5a94f 100644 --- a/sqlbuilder/keyword.go +++ b/sqlbuilder/keyword.go @@ -4,6 +4,10 @@ const ( DEFAULT keywordClause = "DEFAULT" ) +var ( + NULL = newNullExpression() +) + type keywordClause string func (k keywordClause) serialize(statement statementType, out *queryData, options ...serializeOption) error { diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 247b1a0..4331570 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -175,3 +175,21 @@ func Date(year, month, day int) DateExpression { return dateLiteral.CAST_TO_DATE() } + +//--------------------------------------------------// +type nullExpression struct { + expressionInterfaceImpl +} + +func newNullExpression() expression { + nullExpression := &nullExpression{} + + nullExpression.expressionInterfaceImpl.parent = nullExpression + + return nullExpression +} + +func (n *nullExpression) serialize(statement statementType, out *queryData, options ...serializeOption) error { + out.writeString("NULL") + return nil +} diff --git a/sqlbuilder/operators.go b/sqlbuilder/operators.go index 1e446e3..c224e75 100644 --- a/sqlbuilder/operators.go +++ b/sqlbuilder/operators.go @@ -124,7 +124,7 @@ type caseOperatorImpl struct { func CASE(expression ...expression) caseOperatorExpression { caseExp := &caseOperatorImpl{} - if len(expression) == 1 { + if len(expression) > 0 { caseExp.expression = expression[0] } diff --git a/sqlbuilder/operators_test.go b/sqlbuilder/operators_test.go new file mode 100644 index 0000000..568639b --- /dev/null +++ b/sqlbuilder/operators_test.go @@ -0,0 +1,22 @@ +package sqlbuilder + +import "testing" + +func TestCase1(t *testing.T) { + query := CASE(). + WHEN(table3Col1.EQ(Int(1))).THEN(table3Col1.ADD(Int(1))). + WHEN(table3Col1.EQ(Int(2))).THEN(table3Col1.ADD(Int(2))) + + assertExpressionSerialize(t, query, `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`, + int64(1), int64(1), int64(2), int64(2)) +} + +func TestCase2(t *testing.T) { + query := CASE(table3Col1). + WHEN(Int(1)).THEN(table3Col1.ADD(Int(1))). + WHEN(Int(2)).THEN(table3Col1.ADD(Int(2))). + ELSE(Int(0)) + + assertExpressionSerialize(t, query, `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`, + int64(1), int64(1), int64(2), int64(2), int64(0)) +} diff --git a/tests/types_test.go b/tests/types_test.go index c12ac6b..0687e44 100644 --- a/tests/types_test.go +++ b/tests/types_test.go @@ -48,6 +48,11 @@ func TestExpressionOperators(t *testing.T) { TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")), TO_NUMBER(String("12,454"), String("99G999D9S")), TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")), + + COALESCE(AllTypes.IntegerPtr, AllTypes.SmallintPtr, NULL, Int(11)), + NULLIF(AllTypes.Text, String("(none)")), + GREATEST(AllTypes.Numeric, AllTypes.NumericPtr), + LEAST(AllTypes.Numeric, AllTypes.NumericPtr), ) fmt.Println(query.DebugSql())