diff --git a/README.md b/README.md index 675d9e2..f426c63 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ convert database query result into desired arbitrary object structure. Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases. ![jet](https://github.com/go-jet/jet/wiki/image/jet.png) -Jet is the easiest and fastest way to write complex SQL queries and map database query result +Jet is the easiest and the fastest way to write complex SQL queries and map database query result into complex object composition. __It is not an ORM.__ ## Contents @@ -265,7 +265,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; #### Execute query and store result -Well formed SQL is just a first half the job. Lets see how can we make some sense of result set returned executing +Well formed SQL is just a first half of the job. Lets see how can we make some sense of result set returned executing above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest. First we have to create desired structure to store query result set. diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 5a3f1e9..738074b 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -134,7 +134,8 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { // ClauseOrderBy struct type ClauseOrderBy struct { - List []OrderByClause + List []OrderByClause + SkipNewLine bool } // Serialize serializes clause into SQLBuilder @@ -143,7 +144,9 @@ func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) return } - out.NewLine() + if !o.SkipNewLine { + out.NewLine() + } out.WriteString("ORDER BY") out.IncreaseIdent() @@ -469,3 +472,37 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString(string(i.LockMode)) out.WriteString("MODE") } + +// WindowDefinition struct +type WindowDefinition struct { + Name string + Window Window +} + +// ClauseWindow struct +type ClauseWindow struct { + Definitions []WindowDefinition +} + +// Serialize serializes clause into SQLBuilder +func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { + if len(i.Definitions) == 0 { + return + } + + out.NewLine() + out.WriteString("WINDOW") + + for i, def := range i.Definitions { + if i > 0 { + out.WriteString(", ") + } + out.WriteString(def.Name) + out.WriteString("AS") + if def.Window == nil { + out.WriteString("()") + continue + } + def.Window.serialize(statementType, out) + } +} diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 2489b85..91f200a 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -81,68 +81,154 @@ func LOG(floatExpression FloatExpression) FloatExpression { // ----------------- Aggregate functions -------------------// // AVG is aggregate function used to calculate avg value from numeric expression -func AVG(numericExpression NumericExpression) FloatExpression { - return NewFloatFunc("AVG", numericExpression) +func AVG(numericExpression NumericExpression) floatWindowExpression { + return NewFloatWindowFunc("AVG", numericExpression) } // BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. -func BIT_AND(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("BIT_AND", integerExpression) +func BIT_AND(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("BIT_AND", integerExpression) } // BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none. -func BIT_OR(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("BIT_OR", integerExpression) +func BIT_OR(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("BIT_OR", integerExpression) } // BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false -func BOOL_AND(boolExpression BoolExpression) BoolExpression { - return newBoolFunc("BOOL_AND", boolExpression) +func BOOL_AND(boolExpression BoolExpression) boolWindowExpression { + return newBoolWindowFunc("BOOL_AND", boolExpression) } // BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false -func BOOL_OR(boolExpression BoolExpression) BoolExpression { - return newBoolFunc("BOOL_OR", boolExpression) +func BOOL_OR(boolExpression BoolExpression) boolWindowExpression { + return newBoolWindowFunc("BOOL_OR", boolExpression) } // COUNT is aggregate function. Returns number of input rows for which the value of expression is not null. -func COUNT(expression Expression) IntegerExpression { - return newIntegerFunc("COUNT", expression) +func COUNT(expression Expression) integerWindowExpression { + return newIntegerWindowFunc("COUNT", expression) } // EVERY is aggregate function. Returns true if all input values are true, otherwise false -func EVERY(boolExpression BoolExpression) BoolExpression { - return newBoolFunc("EVERY", boolExpression) +func EVERY(boolExpression BoolExpression) boolWindowExpression { + return newBoolWindowFunc("EVERY", boolExpression) } // MAXf is aggregate function. Returns maximum value of float expression across all input values -func MAXf(floatExpression FloatExpression) FloatExpression { - return NewFloatFunc("MAX", floatExpression) +func MAXf(floatExpression FloatExpression) floatWindowExpression { + return NewFloatWindowFunc("MAX", floatExpression) } // MAXi is aggregate function. Returns maximum value of int expression across all input values -func MAXi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("MAX", integerExpression) +func MAXi(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("MAX", integerExpression) } // MINf is aggregate function. Returns minimum value of float expression across all input values -func MINf(floatExpression FloatExpression) FloatExpression { - return NewFloatFunc("MIN", floatExpression) +func MINf(floatExpression FloatExpression) floatWindowExpression { + return NewFloatWindowFunc("MIN", floatExpression) } // MINi is aggregate function. Returns minimum value of int expression across all input values -func MINi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("MIN", integerExpression) +func MINi(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("MIN", integerExpression) } // SUMf is aggregate function. Returns sum of expression across all float expressions -func SUMf(floatExpression FloatExpression) FloatExpression { - return NewFloatFunc("SUM", floatExpression) +func SUMf(floatExpression FloatExpression) floatWindowExpression { + return NewFloatWindowFunc("SUM", floatExpression) } // SUMi is aggregate function. Returns sum of expression across all integer expression. -func SUMi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("SUM", integerExpression) +func SUMi(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("SUM", integerExpression) +} + +// ----------------- Window functions -------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +func ROW_NUMBER() integerWindowExpression { + return newIntegerWindowFunc("ROW_NUMBER") +} + +// RANK of the current row with gaps; same as row_number of its first peer +func RANK() integerWindowExpression { + return newIntegerWindowFunc("RANK") +} + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +func DENSE_RANK() integerWindowExpression { + return newIntegerWindowFunc("DENSE_RANK") +} + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +func PERCENT_RANK() floatWindowExpression { + return NewFloatWindowFunc("PERCENT_RANK") +} + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +func CUME_DIST() floatWindowExpression { + return NewFloatWindowFunc("CUME_DIST") +} + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +func NTILE(numOfBuckets int64) integerWindowExpression { + return newIntegerWindowFunc("NTILE", FixedLiteral(numOfBuckets)) +} + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +func LAG(expr Expression, offsetAndDefault ...interface{}) windowExpression { + return leadLagImpl("LAG", expr, offsetAndDefault...) +} + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +func LEAD(expr Expression, offsetAndDefault ...interface{}) windowExpression { + return leadLagImpl("LEAD", expr, offsetAndDefault...) +} + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +func FIRST_VALUE(value Expression) windowExpression { + return newWindowFunc("FIRST_VALUE", value) +} + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +func LAST_VALUE(value Expression) windowExpression { + return newWindowFunc("LAST_VALUE", value) +} + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +func NTH_VALUE(value Expression, nth int64) windowExpression { + return newWindowFunc("NTH_VALUE", value, FixedLiteral(nth)) +} + +func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) windowExpression { + params := []Expression{expr} + + if len(offsetAndDefault) >= 2 { + offset, ok := offsetAndDefault[0].(int) + if !ok { + panic("jet: LAG offset should be an integer") + } + + var defaultValue Expression + + defaultValue, ok = offsetAndDefault[1].(Expression) + + if !ok { + defaultValue = literal(offsetAndDefault[1]) + } + + params = append(params, FixedLiteral(offset), defaultValue) + } + + return newWindowFunc(name, params...) } //------------ String functions ------------------// @@ -349,7 +435,7 @@ func TO_HEX(number IntegerExpression) StringExpression { // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression { if len(matchType) > 0 { - return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0])) + return newBoolFunc("REGEXP_LIKE", stringExp, pattern, FixedLiteral(matchType[0])) } return newBoolFunc("REGEXP_LIKE", stringExp, pattern) @@ -391,7 +477,7 @@ func CURRENT_TIME(precision ...int) TimezExpression { var timezFunc *timezFunc if len(precision) > 0 { - timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0])) + timezFunc = newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0])) } else { timezFunc = newTimezFunc("CURRENT_TIME") } @@ -406,7 +492,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression { var timestampzFunc *timestampzFunc if len(precision) > 0 { - timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0])) + timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0])) } else { timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") } @@ -421,7 +507,7 @@ func LOCALTIME(precision ...int) TimeExpression { var timeFunc *timeFunc if len(precision) > 0 { - timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0])) + timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0])) } else { timeFunc = newTimeFunc("LOCALTIME") } @@ -436,7 +522,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression { var timestampFunc *timestampFunc if len(precision) > 0 { - timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0])) + timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0])) } else { timestampFunc = NewTimestampFunc("LOCALTIMESTAMP") } @@ -504,6 +590,16 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } +// NewFloatWindowFunc creates new float function with name and expressions +func newWindowFunc(name string, expressions ...Expression) windowExpression { + + newFun := newFunc(name, expressions, nil) + windowExpr := newWindowExpression(newFun) + newFun.expressionInterfaceImpl.Parent = windowExpr + + return windowExpr +} + func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { serializeOverrideFunc := serializeOverride(f.expressions...) @@ -536,10 +632,23 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression { boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.boolInterfaceImpl.parent = boolFunc + boolFunc.expressionInterfaceImpl.Parent = boolFunc return boolFunc } +// NewFloatWindowFunc creates new float function with name and expressions +func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { + boolFunc := &boolFunc{} + + boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + intWindowFunc := newBoolWindowExpression(boolFunc) + boolFunc.boolInterfaceImpl.parent = intWindowFunc + boolFunc.expressionInterfaceImpl.Parent = intWindowFunc + + return intWindowFunc +} + type floatFunc struct { funcExpressionImpl floatInterfaceImpl @@ -555,6 +664,18 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression { return floatFunc } +// NewFloatWindowFunc creates new float function with name and expressions +func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { + floatFunc := &floatFunc{} + + floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatWindowFunc := newFloatWindowExpression(floatFunc) + floatFunc.floatInterfaceImpl.parent = floatWindowFunc + floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc + + return floatWindowFunc +} + type integerFunc struct { funcExpressionImpl integerInterfaceImpl @@ -569,6 +690,18 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { return floatFunc } +// NewFloatWindowFunc creates new float function with name and expressions +func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { + integerFunc := &integerFunc{} + + integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) + intWindowFunc := newIntegerWindowExpression(integerFunc) + integerFunc.integerInterfaceImpl.parent = intWindowFunc + integerFunc.expressionInterfaceImpl.Parent = intWindowFunc + + return intWindowFunc +} + type stringFunc struct { funcExpressionImpl stringInterfaceImpl diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 4851fa6..68fb429 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -32,8 +32,8 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl return &exp } -// ConstLiteral is injected directly to SQL query, and does not appear in argument list. -func ConstLiteral(value interface{}) *literalExpressionImpl { +// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. +func FixedLiteral(value interface{}) *literalExpressionImpl { exp := literal(value) exp.constant = true diff --git a/internal/jet/window_expression.go b/internal/jet/window_expression.go new file mode 100644 index 0000000..3e7f1c7 --- /dev/null +++ b/internal/jet/window_expression.go @@ -0,0 +1,146 @@ +package jet + +type commonWindowImpl struct { + expression Expression + window Window +} + +func (w *commonWindowImpl) over(window ...Window) { + if len(window) > 0 { + w.window = window[0] + } else { + w.window = newWindowImpl(nil) + } +} + +func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + w.expression.serialize(statement, out) + if w.window != nil { + out.WriteString("OVER") + w.window.serialize(statement, out) + } +} + +// -------------------------------------- + +type windowExpression interface { + Expression + OVER(window ...Window) Expression +} + +func newWindowExpression(Exp Expression) windowExpression { + newExp := &windowExpressionImpl{ + Expression: Exp, + } + + newExp.commonWindowImpl.expression = Exp + + return newExp +} + +type windowExpressionImpl struct { + Expression + commonWindowImpl +} + +func (f *windowExpressionImpl) OVER(window ...Window) Expression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} + +// ----------------------------------------------------- + +type floatWindowExpression interface { + FloatExpression + OVER(window ...Window) FloatExpression +} + +func newFloatWindowExpression(floatExp FloatExpression) floatWindowExpression { + newExp := &floatWindowExpressionImpl{ + FloatExpression: floatExp, + } + + newExp.commonWindowImpl.expression = floatExp + + return newExp +} + +type floatWindowExpressionImpl struct { + FloatExpression + commonWindowImpl +} + +func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} + +// ------------------------------------------------ + +type integerWindowExpression interface { + IntegerExpression + OVER(window ...Window) IntegerExpression +} + +func newIntegerWindowExpression(intExp IntegerExpression) integerWindowExpression { + newExp := &integerWindowExpressionImpl{ + IntegerExpression: intExp, + } + + newExp.commonWindowImpl.expression = intExp + + return newExp +} + +type integerWindowExpressionImpl struct { + IntegerExpression + commonWindowImpl +} + +func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} + +// ------------------------------------------------ + +type boolWindowExpression interface { + BoolExpression + OVER(window ...Window) BoolExpression +} + +func newBoolWindowExpression(boolExp BoolExpression) boolWindowExpression { + newExp := &boolWindowExpressionImpl{ + BoolExpression: boolExp, + } + + newExp.commonWindowImpl.expression = boolExp + + return newExp +} + +type boolWindowExpressionImpl struct { + BoolExpression + commonWindowImpl +} + +func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} diff --git a/internal/jet/window_func.go b/internal/jet/window_func.go new file mode 100644 index 0000000..7f4d1b7 --- /dev/null +++ b/internal/jet/window_func.go @@ -0,0 +1,186 @@ +package jet + +// Window interface +type Window interface { + Serializer + ORDER_BY(expr ...OrderByClause) Window + ROWS(start FrameExtent, end ...FrameExtent) Window + RANGE(start FrameExtent, end ...FrameExtent) Window + GROUPS(start FrameExtent, end ...FrameExtent) Window +} + +type windowImpl struct { + partitionBy []Expression + orderBy ClauseOrderBy + frameUnits string + start, end FrameExtent + + parent Window +} + +func newWindowImpl(parent Window) *windowImpl { + newWindow := &windowImpl{} + if parent == nil { + newWindow.parent = newWindow + } else { + newWindow.parent = parent + } + + return newWindow +} + +func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if !contains(options, noWrap) { + out.WriteByte('(') + } + + if w.partitionBy != nil { + out.WriteString("PARTITION BY") + + serializeExpressionList(statement, w.partitionBy, ", ", out) + } + w.orderBy.SkipNewLine = true + w.orderBy.Serialize(statement, out) + + if w.frameUnits != "" { + out.WriteString(w.frameUnits) + + if w.end == nil { + w.start.serialize(statement, out) + } else { + out.WriteString("BETWEEN") + w.start.serialize(statement, out) + out.WriteString("AND") + w.end.serialize(statement, out) + } + } + + if !contains(options, noWrap) { + out.WriteByte(')') + } +} + +func (w *windowImpl) ORDER_BY(exprs ...OrderByClause) Window { + w.orderBy.List = exprs + return w.parent +} + +func (w *windowImpl) ROWS(start FrameExtent, end ...FrameExtent) Window { + w.frameUnits = "ROWS" + w.setFrameRange(start, end...) + return w.parent +} + +func (w *windowImpl) RANGE(start FrameExtent, end ...FrameExtent) Window { + w.frameUnits = "RANGE" + w.setFrameRange(start, end...) + return w.parent +} + +func (w *windowImpl) GROUPS(start FrameExtent, end ...FrameExtent) Window { + w.frameUnits = "GROUPS" + w.setFrameRange(start, end...) + return w.parent +} + +func (w *windowImpl) setFrameRange(start FrameExtent, end ...FrameExtent) { + w.start = start + if len(end) > 0 { + w.end = end[0] + } +} + +// PARTITION_BY window function constructor +func PARTITION_BY(exp Expression, exprs ...Expression) Window { + funImpl := newWindowImpl(nil) + funImpl.partitionBy = append([]Expression{exp}, exprs...) + return funImpl +} + +// ORDER_BY window function constructor +func ORDER_BY(expr ...OrderByClause) Window { + funImpl := newWindowImpl(nil) + funImpl.orderBy.List = expr + return funImpl +} + +// ----------------------------------------------- + +// FrameExtent interface +type FrameExtent interface { + Serializer + isFrameExtent() +} + +// PRECEDING window frame clause +func PRECEDING(offset Serializer) FrameExtent { + return &frameExtentImpl{ + preceding: true, + offset: offset, + } +} + +// FOLLOWING window frame clause +func FOLLOWING(offset Serializer) FrameExtent { + return &frameExtentImpl{ + preceding: false, + offset: offset, + } +} + +type frameExtentImpl struct { + preceding bool + offset Serializer +} + +func (f *frameExtentImpl) isFrameExtent() {} + +func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if f == nil { + return + } + f.offset.serialize(statement, out) + + if f.preceding { + out.WriteString("PRECEDING") + } else { + out.WriteString("FOLLOWING") + } +} + +// ----------------------------------------------- + +// Window function keywords +var ( + UNBOUNDED = keywordClause("UNBOUNDED") + CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"} +) + +type frameExtentKeyword struct { + keywordClause +} + +func (f frameExtentKeyword) isFrameExtent() {} + +// ----------------------------------------------- + +// WindowName is used to specify window reference from WINDOW clause +func WindowName(name string) Window { + newWindow := &windowName{name: name} + newWindow.parent = newWindow + return newWindow +} + +type windowName struct { + windowImpl + name string +} + +func (w windowName) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteByte('(') + + out.WriteString(w.name) + w.windowImpl.serialize(statement, out, noWrap) + + out.WriteByte(')') +} diff --git a/internal/jet/window_func_test.go b/internal/jet/window_func_test.go new file mode 100644 index 0000000..74ae9e9 --- /dev/null +++ b/internal/jet/window_func_test.go @@ -0,0 +1,21 @@ +package jet + +import "testing" + +func TestFrameExtent(t *testing.T) { + assertClauseSerialize(t, PRECEDING(Int(2)), "$1 PRECEDING", int64(2)) + assertClauseSerialize(t, FOLLOWING(Int(4)), "$1 FOLLOWING", int64(4)) +} + +func TestWindowFunctions(t *testing.T) { + assertClauseSerialize(t, PARTITION_BY(table1Col1), "(PARTITION BY table1.col1)") + assertClauseSerialize(t, PARTITION_BY(table1Col3).ORDER_BY(table1Col1), "(PARTITION BY table1.col3 ORDER BY table1.col1)") + assertClauseSerialize(t, ORDER_BY(table1Col1), "(ORDER BY table1.col1)") + assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1))), "(ORDER BY table1.col1 ROWS $1 PRECEDING)", int64(1)) + assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1)), FOLLOWING(Int(33))), + "(ORDER BY table1.col1 ROWS BETWEEN $1 PRECEDING AND $2 FOLLOWING)", int64(1), int64(33)) + assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + "(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)") + assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), CURRENT_ROW), + "(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)") +} diff --git a/mysql/functions.go b/mysql/functions.go index 2c911ea..0064d9d 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -85,6 +85,47 @@ var SUMi = jet.SUMi // SUMf is aggregate function. Returns sum of float expression. var SUMf = jet.SUMf +// -------------------- Window functions -----------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +var ROW_NUMBER = jet.ROW_NUMBER + +// RANK of the current row with gaps; same as row_number of its first peer +var RANK = jet.RANK + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +var DENSE_RANK = jet.DENSE_RANK + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +var PERCENT_RANK = jet.PERCENT_RANK + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +var CUME_DIST = jet.CUME_DIST + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +var NTILE = jet.NTILE + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LAG = jet.LAG + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LEAD = jet.LEAD + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +var FIRST_VALUE = jet.FIRST_VALUE + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +var LAST_VALUE = jet.LAST_VALUE + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +var NTH_VALUE = jet.NTH_VALUE + //--------------------- String functions ------------------// // BIT_LENGTH returns number of bits in string expression @@ -181,7 +222,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampExpression { // NOW returns current datetime func NOW(fsp ...int) DateTimeExpression { if len(fsp) > 0 { - return jet.NewTimestampFunc("NOW", jet.ConstLiteral(int64(fsp[0]))) + return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0]))) } return jet.NewTimestampFunc("NOW") } diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 5622f3e..3347888 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -1,6 +1,8 @@ package mysql -import "github.com/go-jet/jet/internal/jet" +import ( + "github.com/go-jet/jet/internal/jet" +) // RowLock is interface for SELECT statement row lock types type RowLock = jet.RowLock @@ -11,6 +13,27 @@ var ( SHARE = jet.NewRowLock("SHARE") ) +// Window function clauses +var ( + PARTITION_BY = jet.PARTITION_BY + ORDER_BY = jet.ORDER_BY + UNBOUNDED = jet.UNBOUNDED + CURRENT_ROW = jet.CURRENT_ROW +) + +// PRECEDING window frame clause +func PRECEDING(offset interface{}) jet.FrameExtent { + return jet.PRECEDING(toJetFrameOffset(offset)) +} + +// FOLLOWING window frame clause +func FOLLOWING(offset interface{}) jet.FrameExtent { + return jet.FOLLOWING(toJetFrameOffset(offset)) +} + +// Window is used to specify window reference from WINDOW clause +var Window = jet.WindowName + // SelectStatement is interface for MySQL SELECT statement type SelectStatement interface { Statement @@ -22,6 +45,7 @@ type SelectStatement interface { WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement + WINDOW(name string) windowExpand ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement @@ -42,7 +66,7 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) newSelect.Select.Projections = toJetProjectionList(projections) @@ -66,6 +90,7 @@ type selectStatementImpl struct { Where jet.ClauseWhere GroupBy jet.ClauseGroupBy Having jet.ClauseHaving + Window jet.ClauseWindow OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit Offset jet.ClauseOffset @@ -98,6 +123,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem return s } +func (s *selectStatementImpl) WINDOW(name string) windowExpand { + s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{selectStatement: s} +} + func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { s.OrderBy.List = orderByClauses return s @@ -126,3 +156,31 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { func (s *selectStatementImpl) AsTable(alias string) SelectTable { return newSelectTable(s, alias) } + +//----------------------------------------------------- + +type windowExpand struct { + selectStatement *selectStatementImpl +} + +func (w windowExpand) AS(window ...jet.Window) SelectStatement { + if len(window) == 0 { + return w.selectStatement + } + windowsDefinition := w.selectStatement.Window.Definitions + windowsDefinition[len(windowsDefinition)-1].Window = window[0] + return w.selectStatement +} + +func toJetFrameOffset(offset interface{}) jet.Serializer { + if offset == UNBOUNDED { + return jet.UNBOUNDED + } + + // check for interval expression + //if exp, ok := offset.(Expression); ok { + // return exp + //} + + return jet.FixedLiteral(offset) +} diff --git a/postgres/functions.go b/postgres/functions.go index 18a637f..4f657ca 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -87,6 +87,47 @@ var SUMf = jet.SUMf // SUMi is aggregate function. Returns sum of expression across all integer expression. var SUMi = jet.SUMi +// -------------------- Window functions -----------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +var ROW_NUMBER = jet.ROW_NUMBER + +// RANK of the current row with gaps; same as row_number of its first peer +var RANK = jet.RANK + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +var DENSE_RANK = jet.DENSE_RANK + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +var PERCENT_RANK = jet.PERCENT_RANK + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +var CUME_DIST = jet.CUME_DIST + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +var NTILE = jet.NTILE + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LAG = jet.LAG + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LEAD = jet.LEAD + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +var FIRST_VALUE = jet.FIRST_VALUE + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +var LAST_VALUE = jet.LAST_VALUE + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +var NTH_VALUE = jet.NTH_VALUE + //--------------------- String functions ------------------// // BIT_LENGTH returns number of bits in string expression diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 34225e6..e4aeb37 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -1,6 +1,9 @@ package postgres -import "github.com/go-jet/jet/internal/jet" +import ( + "github.com/go-jet/jet/internal/jet" + "math" +) // RowLock is interface for SELECT statement row lock types type RowLock = jet.RowLock @@ -13,6 +16,27 @@ var ( KEY_SHARE = jet.NewRowLock("KEY SHARE") ) +// Window function clauses +var ( + PARTITION_BY = jet.PARTITION_BY + ORDER_BY = jet.ORDER_BY + UNBOUNDED = int64(math.MaxInt64) + CURRENT_ROW = jet.CURRENT_ROW +) + +// PRECEDING window frame clause +func PRECEDING(offset int64) jet.FrameExtent { + return jet.PRECEDING(toJetFrameOffset(offset)) +} + +// FOLLOWING window frame clause +func FOLLOWING(offset int64) jet.FrameExtent { + return jet.FOLLOWING(toJetFrameOffset(offset)) +} + +// Window definition reference +var Window = jet.WindowName + // SelectStatement is interface for PostgreSQL SELECT statement type SelectStatement interface { Statement @@ -24,6 +48,7 @@ type SelectStatement interface { WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement + WINDOW(name string) windowExpand ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement @@ -47,15 +72,9 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.Limit, &newSelect.Offset, &newSelect.For) - // statementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - // &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, - // &newSelect.Limit, &newSelect.Offset, &newSelect.For) - // - //newSelect.expressionStatementImpl.expressionInterfaceImpl.Parent = newSelect - newSelect.Select.Projections = toJetProjectionList(projections) newSelect.From.Table = table newSelect.Limit.Count = -1 @@ -75,6 +94,7 @@ type selectStatementImpl struct { Where jet.ClauseWhere GroupBy jet.ClauseGroupBy Having jet.ClauseHaving + Window jet.ClauseWindow OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit Offset jet.ClauseOffset @@ -106,6 +126,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem return s } +func (s *selectStatementImpl) WINDOW(name string) windowExpand { + s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{selectStatement: s} +} + func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { s.OrderBy.List = orderByClauses return s @@ -129,3 +154,25 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { func (s *selectStatementImpl) AsTable(alias string) SelectTable { return newSelectTable(s, alias) } + +//----------------------------------------------------- + +type windowExpand struct { + selectStatement *selectStatementImpl +} + +func (w windowExpand) AS(window ...jet.Window) SelectStatement { + if len(window) == 0 { + return w.selectStatement + } + windowsDefinition := w.selectStatement.Window.Definitions + windowsDefinition[len(windowsDefinition)-1].Window = window[0] + return w.selectStatement +} + +func toJetFrameOffset(offset int64) jet.Serializer { + if offset == UNBOUNDED { + return jet.UNBOUNDED + } + return jet.FixedLiteral(offset) +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 1e9b73b..d0dac9c 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -1,6 +1,7 @@ package mysql import ( + "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" @@ -527,3 +528,110 @@ LOCK IN SHARE MODE; err := query.Query(db, &struct{}{}) assert.NilError(t, err) } + +func TestWindowFunction(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (PARTITION BY payment.customer_id), + MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC), + MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING), + MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + ROW_NUMBER() OVER (ORDER BY payment.payment_date), + RANK() OVER (ORDER BY payment.payment_date), + DENSE_RANK() OVER (ORDER BY payment.payment_date), + CUME_DIST() OVER (ORDER BY payment.payment_date), + NTILE(11) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date) +FROM dvds.payment +WHERE payment.payment_id < ? +GROUP BY payment.amount, payment.customer_id, payment.payment_date; +` + query := Payment. + SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)), + MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())), + MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))), + MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)), + RANK().OVER(ORDER_BY(Payment.PaymentDate)), + DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)), + CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)), + NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)), + ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). + WHERE(Payment.PaymentID.LT(Int(10))) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestWindowClause(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (w1), + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +FROM dvds.payment +WHERE payment.payment_id < ? +WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) +ORDER BY payment.customer_id; +` + query := Payment.SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(Window("w1")), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + ). + WHERE(Payment.PaymentID.LT(Int(10))). + WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY(Payment.CustomerID) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) + + err := query.Query(db, &struct{}{}) + + assert.NilError(t, err) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 03e5672..1fa2421 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1615,3 +1615,110 @@ SELECT true, err := query.Query(db, &struct{}{}) assert.NilError(t, err) } + +func TestWindowFunction(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (PARTITION BY payment.customer_id), + MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC), + MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING), + MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + ROW_NUMBER() OVER (ORDER BY payment.payment_date), + RANK() OVER (ORDER BY payment.payment_date), + DENSE_RANK() OVER (ORDER BY payment.payment_date), + CUME_DIST() OVER (ORDER BY payment.payment_date), + NTILE(11) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, $1) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, $2) OVER (ORDER BY payment.payment_date), + FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date) +FROM dvds.payment +WHERE payment.payment_id < $3 +GROUP BY payment.amount, payment.customer_id, payment.payment_date; +` + query := Payment. + SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)), + MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())), + MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))), + MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)), + RANK().OVER(ORDER_BY(Payment.PaymentDate)), + DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)), + CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)), + NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)), + ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). + WHERE(Payment.PaymentID.LT(Int(10))) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestWindowClause(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (w1), + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +FROM dvds.payment +WHERE payment.payment_id < $1 +WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) +ORDER BY payment.customer_id; +` + query := Payment.SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(Window("w1")), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + ). + WHERE(Payment.PaymentID.LT(Int(10))). + WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY(Payment.CustomerID) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) + + err := query.Query(db, &struct{}{}) + + assert.NilError(t, err) +}