From d69c67569a54646ae72b0f8cb1742ef148100a0e Mon Sep 17 00:00:00 2001 From: zer0sub Date: Mon, 3 Jun 2019 18:28:16 +0200 Subject: [PATCH] Aggregate functions --- sqlbuilder/func_expression.go | 57 ++++++++++++---- sqlbuilder/func_expression_test.go | 106 +++++++++++++++++++---------- sqlbuilder/keyword.go | 3 +- sqlbuilder/literal_expression.go | 26 +++++-- tests/select_test.go | 26 +++++-- 5 files changed, 159 insertions(+), 59 deletions(-) diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index ea1f425..b2453ad 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -10,10 +10,6 @@ type funcExpressionImpl struct { noBrackets bool } -func ROW(expressions ...expression) expression { - return newFunc("ROW", expressions, nil) -} - func newFunc(name string, expressions []expression, parent expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ name: name, @@ -53,6 +49,9 @@ func (f *funcExpressionImpl) serialize(statement statementType, out *queryData, return nil } +func ROW(expressions ...expression) expression { + return newFunc("ROW", expressions, nil) +} type boolFunc struct { funcExpressionImpl @@ -240,7 +239,39 @@ func LOG(floatExpression FloatExpression) FloatExpression { return newFloatFunc("LOG", floatExpression) } -// ----------------- Group function operators -------------------// +// ----------------- Aggregate functions -------------------// + +func AVGf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("AVG", floatExpression) +} + +func AVGi(integerExpression IntegerExpression) FloatExpression { + return newFloatFunc("AVG", integerExpression) +} + +func BIT_AND(integerExpression IntegerExpression) IntegerExpression { + return newIntegerFunc("BIT_AND", integerExpression) +} + +func BIT_OR(integerExpression IntegerExpression) IntegerExpression { + return newIntegerFunc("BIT_OR", integerExpression) +} + +func BOOL_AND(boolExpression BoolExpression) BoolExpression { + return newBoolFunc("BOOL_AND", boolExpression) +} + +func BOOL_OR(boolExpression BoolExpression) BoolExpression { + return newBoolFunc("BOOL_OR", boolExpression) +} + +func COUNT(expression expression) IntegerExpression { + return newIntegerFunc("COUNT", expression) +} + +func EVERY(boolExpression BoolExpression) BoolExpression { + return newBoolFunc("EVERY", boolExpression) +} func MAXf(floatExpression FloatExpression) FloatExpression { return newFloatFunc("MAX", floatExpression) @@ -250,6 +281,14 @@ func MAXi(integerExpression IntegerExpression) IntegerExpression { return newIntegerFunc("MAX", integerExpression) } +func MINf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("MIN", floatExpression) +} + +func MINi(integerExpression IntegerExpression) IntegerExpression { + return newIntegerFunc("MIN", integerExpression) +} + func SUMf(floatExpression FloatExpression) FloatExpression { return newFloatFunc("SUM", floatExpression) } @@ -258,14 +297,6 @@ func SUMi(integerExpression IntegerExpression) IntegerExpression { return newIntegerFunc("SUM", integerExpression) } -func COUNTf(floatExpression FloatExpression) FloatExpression { - return newFloatFunc("COUNT", floatExpression) -} - -func COUNTi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("COUNT", integerExpression) -} - //------------ String functions ------------------// func BIT_LENGTH(stringExpression StringExpression) IntegerExpression { diff --git a/sqlbuilder/func_expression_test.go b/sqlbuilder/func_expression_test.go index b6276a5..ae0d1fc 100644 --- a/sqlbuilder/func_expression_test.go +++ b/sqlbuilder/func_expression_test.go @@ -5,6 +5,76 @@ import ( "testing" ) +func TestFuncAVG(t *testing.T) { + t.Run("float", func(t *testing.T) { + assertExpressionSerialize(t, AVGf(table1ColFloat), "AVG(table1.colFloat)") + }) + + t.Run("integer", func(t *testing.T) { + assertExpressionSerialize(t, AVGi(table1ColInt), "AVG(table1.colInt)") + }) +} + +func TestFuncBIT_AND(t *testing.T) { + assertExpressionSerialize(t, BIT_AND(table1ColInt), "BIT_AND(table1.colInt)") +} + +func TestFuncBIT_OR(t *testing.T) { + assertExpressionSerialize(t, BIT_OR(table1ColInt), "BIT_OR(table1.colInt)") +} + +func TestFuncBOOL_AND(t *testing.T) { + assertExpressionSerialize(t, BOOL_AND(table1ColBool), "BOOL_AND(table1.colBool)") +} + +func TestFuncBOOL_OR(t *testing.T) { + assertExpressionSerialize(t, BOOL_OR(table1ColBool), "BOOL_OR(table1.colBool)") +} + +func TestFuncEVERY(t *testing.T) { + assertExpressionSerialize(t, EVERY(table1ColBool), "EVERY(table1.colBool)") +} + +func TestFuncMIN(t *testing.T) { + t.Run("float", func(t *testing.T) { + assertExpressionSerialize(t, MINf(table1ColFloat), "MIN(table1.colFloat)") + }) + + t.Run("integer", func(t *testing.T) { + assertExpressionSerialize(t, MINi(table1ColInt), "MIN(table1.colInt)") + }) +} + +func TestFuncMAX(t *testing.T) { + t.Run("float", func(t *testing.T) { + assertExpressionSerialize(t, MAXf(table1ColFloat), "MAX(table1.colFloat)") + assertExpressionSerialize(t, MAXf(Float(11.2222)), "MAX($1)", float64(11.2222)) + }) + + t.Run("integer", func(t *testing.T) { + assertExpressionSerialize(t, MAXi(table1ColInt), "MAX(table1.colInt)") + assertExpressionSerialize(t, MAXi(Int(11)), "MAX($1)", int64(11)) + }) +} + +func TestFuncSUM(t *testing.T) { + t.Run("float", func(t *testing.T) { + assertExpressionSerialize(t, SUMf(table1ColFloat), "SUM(table1.colFloat)") + assertExpressionSerialize(t, SUMf(Float(11.2222)), "SUM($1)", float64(11.2222)) + }) + + t.Run("integer", func(t *testing.T) { + assertExpressionSerialize(t, SUMi(table1ColInt), "SUM(table1.colInt)") + assertExpressionSerialize(t, SUMi(Int(11)), "SUM($1)", int64(11)) + }) +} + +func TestFuncCOUNT(t *testing.T) { + assertExpressionSerialize(t, COUNT(STAR), "COUNT(*)") + assertExpressionSerialize(t, COUNT(table1ColFloat), "COUNT(table1.colFloat)") + assertExpressionSerialize(t, COUNT(Float(11.2222)), "COUNT($1)", float64(11.2222)) +} + func TestFuncABS(t *testing.T) { t.Run("float", func(t *testing.T) { assertExpressionSerialize(t, ABSf(table1ColFloat), "ABS(table1.colFloat)") @@ -41,42 +111,6 @@ func TestFuncCBRT(t *testing.T) { }) } -func TestFuncMAX(t *testing.T) { - t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, MAXf(table1ColFloat), "MAX(table1.colFloat)") - assertExpressionSerialize(t, MAXf(Float(11.2222)), "MAX($1)", float64(11.2222)) - }) - - t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, MAXi(table1ColInt), "MAX(table1.colInt)") - assertExpressionSerialize(t, MAXi(Int(11)), "MAX($1)", int64(11)) - }) -} - -func TestFuncSUM(t *testing.T) { - t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, SUMf(table1ColFloat), "SUM(table1.colFloat)") - assertExpressionSerialize(t, SUMf(Float(11.2222)), "SUM($1)", float64(11.2222)) - }) - - t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, SUMi(table1ColInt), "SUM(table1.colInt)") - assertExpressionSerialize(t, SUMi(Int(11)), "SUM($1)", int64(11)) - }) -} - -func TestFuncCOUNT(t *testing.T) { - t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, COUNTf(table1ColFloat), "COUNT(table1.colFloat)") - assertExpressionSerialize(t, COUNTf(Float(11.2222)), "COUNT($1)", float64(11.2222)) - }) - - t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, COUNTi(table1ColInt), "COUNT(table1.colInt)") - assertExpressionSerialize(t, COUNTi(Int(11)), "COUNT($1)", int64(11)) - }) -} - func TestFuncCEIL(t *testing.T) { assertExpressionSerialize(t, CEIL(table1ColFloat), "CEIL(table1.colFloat)") assertExpressionSerialize(t, CEIL(Float(11.2222)), "CEIL($1)", float64(11.2222)) diff --git a/sqlbuilder/keyword.go b/sqlbuilder/keyword.go index 5e5a94f..70079b9 100644 --- a/sqlbuilder/keyword.go +++ b/sqlbuilder/keyword.go @@ -5,7 +5,8 @@ const ( ) var ( - NULL = newNullExpression() + NULL = newNullLiteral() + STAR = newStarLiteral() ) type keywordClause string diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 4331570..54f4c8f 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -177,19 +177,37 @@ func Date(year, month, day int) DateExpression { } //--------------------------------------------------// -type nullExpression struct { +type nullLiteral struct { expressionInterfaceImpl } -func newNullExpression() expression { - nullExpression := &nullExpression{} +func newNullLiteral() expression { + nullExpression := &nullLiteral{} nullExpression.expressionInterfaceImpl.parent = nullExpression return nullExpression } -func (n *nullExpression) serialize(statement statementType, out *queryData, options ...serializeOption) error { +func (n *nullLiteral) serialize(statement statementType, out *queryData, options ...serializeOption) error { out.writeString("NULL") return nil } + +//--------------------------------------------------// +type starLiteral struct { + expressionInterfaceImpl +} + +func newStarLiteral() expression { + starExpression := &starLiteral{} + + starExpression.expressionInterfaceImpl.parent = starExpression + + return starExpression +} + +func (n *starLiteral) serialize(statement statementType, out *queryData, options ...serializeOption) error { + out.writeString("*") + return nil +} diff --git a/tests/select_test.go b/tests/select_test.go index e159a6b..da80a48 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -849,7 +849,11 @@ ORDER BY film.film_id ASC; func TestSelectGroupByHaving(t *testing.T) { expectedSql := ` SELECT payment.customer_id AS "customer_payment_sum.customer_id", - SUM(payment.amount) AS "customer_payment_sum.amount_sum" + SUM(payment.amount) AS "customer_payment_sum.amount_sum", + AVG(payment.amount) AS "customer_payment_sum.amount_avg", + MAX(payment.amount) AS "customer_payment_sum.amount_max", + MIN(payment.amount) AS "customer_payment_sum.amount_min", + COUNT(payment.amount) AS "customer_payment_sum.amount_count" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount) > 100 @@ -859,6 +863,10 @@ ORDER BY SUM(payment.amount) ASC; SELECT( Payment.CustomerID.AS("customer_payment_sum.customer_id"), SUMf(Payment.Amount).AS("customer_payment_sum.amount_sum"), + AVGf(Payment.Amount).AS("customer_payment_sum.amount_avg"), + MAXf(Payment.Amount).AS("customer_payment_sum.amount_max"), + MINf(Payment.Amount).AS("customer_payment_sum.amount_min"), + COUNT(Payment.Amount).AS("customer_payment_sum.amount_count"), ). GROUP_BY(Payment.CustomerID). ORDER_BY( @@ -871,8 +879,12 @@ ORDER BY SUM(payment.amount) ASC; assertQuery(t, customersPaymentQuery, expectedSql, float64(100)) type CustomerPaymentSum struct { - CustomerID int16 - AmountSum float64 + CustomerID int16 + AmountSum float64 + AmountAvg float64 + AmountMax float64 + AmountMin float64 + AmountCount int64 } customerPaymentSum := []CustomerPaymentSum{} @@ -883,8 +895,12 @@ ORDER BY SUM(payment.amount) ASC; assert.Equal(t, len(customerPaymentSum), 296) assert.DeepEqual(t, customerPaymentSum[0], CustomerPaymentSum{ - CustomerID: 135, - AmountSum: 100.72, + CustomerID: 135, + AmountSum: 100.72, + AmountAvg: 3.597142857142857, + AmountMax: 7.99, + AmountMin: 0.99, + AmountCount: 28, }) }