Added support for window clause and functions.

This commit is contained in:
go-jet 2019-09-17 13:34:47 +02:00
parent b7363a554b
commit 5ba10d35db
13 changed files with 973 additions and 48 deletions

View file

@ -134,7 +134,8 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
// ClauseOrderBy struct
type ClauseOrderBy struct {
List []OrderByClause
List []OrderByClause
SkipNewLine bool
}
// Serialize serializes clause into SQLBuilder
@ -143,7 +144,9 @@ func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder)
return
}
out.NewLine()
if !o.SkipNewLine {
out.NewLine()
}
out.WriteString("ORDER BY")
out.IncreaseIdent()
@ -469,3 +472,37 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString(string(i.LockMode))
out.WriteString("MODE")
}
// WindowDefinition struct
type WindowDefinition struct {
Name string
Window Window
}
// ClauseWindow struct
type ClauseWindow struct {
Definitions []WindowDefinition
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) {
if len(i.Definitions) == 0 {
return
}
out.NewLine()
out.WriteString("WINDOW")
for i, def := range i.Definitions {
if i > 0 {
out.WriteString(", ")
}
out.WriteString(def.Name)
out.WriteString("AS")
if def.Window == nil {
out.WriteString("()")
continue
}
def.Window.serialize(statementType, out)
}
}

View file

@ -81,68 +81,154 @@ func LOG(floatExpression FloatExpression) FloatExpression {
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression {
return NewFloatFunc("AVG", numericExpression)
func AVG(numericExpression NumericExpression) floatWindowExpression {
return NewFloatWindowFunc("AVG", numericExpression)
}
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
func BIT_AND(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_AND", integerExpression)
func BIT_AND(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("BIT_AND", integerExpression)
}
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
func BIT_OR(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_OR", integerExpression)
func BIT_OR(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("BIT_OR", integerExpression)
}
// BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false
func BOOL_AND(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_AND", boolExpression)
func BOOL_AND(boolExpression BoolExpression) boolWindowExpression {
return newBoolWindowFunc("BOOL_AND", boolExpression)
}
// BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false
func BOOL_OR(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_OR", boolExpression)
func BOOL_OR(boolExpression BoolExpression) boolWindowExpression {
return newBoolWindowFunc("BOOL_OR", boolExpression)
}
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
func COUNT(expression Expression) IntegerExpression {
return newIntegerFunc("COUNT", expression)
func COUNT(expression Expression) integerWindowExpression {
return newIntegerWindowFunc("COUNT", expression)
}
// EVERY is aggregate function. Returns true if all input values are true, otherwise false
func EVERY(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("EVERY", boolExpression)
func EVERY(boolExpression BoolExpression) boolWindowExpression {
return newBoolWindowFunc("EVERY", boolExpression)
}
// MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression {
return NewFloatFunc("MAX", floatExpression)
func MAXf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("MAX", floatExpression)
}
// MAXi is aggregate function. Returns maximum value of int expression across all input values
func MAXi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MAX", integerExpression)
func MAXi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MAX", integerExpression)
}
// MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression {
return NewFloatFunc("MIN", floatExpression)
func MINf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("MIN", floatExpression)
}
// MINi is aggregate function. Returns minimum value of int expression across all input values
func MINi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MIN", integerExpression)
func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MIN", integerExpression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression {
return NewFloatFunc("SUM", floatExpression)
func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("SUM", floatExpression)
}
// SUMi is aggregate function. Returns sum of expression across all integer expression.
func SUMi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("SUM", integerExpression)
func SUMi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("SUM", integerExpression)
}
// ----------------- Window functions -------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
func ROW_NUMBER() integerWindowExpression {
return newIntegerWindowFunc("ROW_NUMBER")
}
// RANK of the current row with gaps; same as row_number of its first peer
func RANK() integerWindowExpression {
return newIntegerWindowFunc("RANK")
}
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
func DENSE_RANK() integerWindowExpression {
return newIntegerWindowFunc("DENSE_RANK")
}
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
func PERCENT_RANK() floatWindowExpression {
return NewFloatWindowFunc("PERCENT_RANK")
}
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
func CUME_DIST() floatWindowExpression {
return NewFloatWindowFunc("CUME_DIST")
}
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
func NTILE(numOfBuckets int64) integerWindowExpression {
return newIntegerWindowFunc("NTILE", FixedLiteral(numOfBuckets))
}
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
func LAG(expr Expression, offsetAndDefault ...interface{}) windowExpression {
return leadLagImpl("LAG", expr, offsetAndDefault...)
}
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
func LEAD(expr Expression, offsetAndDefault ...interface{}) windowExpression {
return leadLagImpl("LEAD", expr, offsetAndDefault...)
}
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
func FIRST_VALUE(value Expression) windowExpression {
return newWindowFunc("FIRST_VALUE", value)
}
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
func LAST_VALUE(value Expression) windowExpression {
return newWindowFunc("LAST_VALUE", value)
}
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
func NTH_VALUE(value Expression, nth int64) windowExpression {
return newWindowFunc("NTH_VALUE", value, FixedLiteral(nth))
}
func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) windowExpression {
params := []Expression{expr}
if len(offsetAndDefault) >= 2 {
offset, ok := offsetAndDefault[0].(int)
if !ok {
panic("jet: LAG offset should be an integer")
}
var defaultValue Expression
defaultValue, ok = offsetAndDefault[1].(Expression)
if !ok {
defaultValue = literal(offsetAndDefault[1])
}
params = append(params, FixedLiteral(offset), defaultValue)
}
return newWindowFunc(name, params...)
}
//------------ String functions ------------------//
@ -349,7 +435,7 @@ func TO_HEX(number IntegerExpression) StringExpression {
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression {
if len(matchType) > 0 {
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0]))
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, FixedLiteral(matchType[0]))
}
return newBoolFunc("REGEXP_LIKE", stringExp, pattern)
@ -391,7 +477,7 @@ func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0]))
timezFunc = newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
@ -406,7 +492,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0]))
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
@ -421,7 +507,7 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0]))
timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
@ -436,7 +522,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0]))
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0]))
} else {
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP")
}
@ -504,6 +590,16 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp
}
// NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := newFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun)
newFun.expressionInterfaceImpl.Parent = windowExpr
return windowExpr
}
func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(f.expressions...)
@ -536,10 +632,23 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression {
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
boolFunc.boolInterfaceImpl.parent = boolFunc
boolFunc.expressionInterfaceImpl.Parent = boolFunc
return boolFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression {
boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
intWindowFunc := newBoolWindowExpression(boolFunc)
boolFunc.boolInterfaceImpl.parent = intWindowFunc
boolFunc.expressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc
}
type floatFunc struct {
funcExpressionImpl
floatInterfaceImpl
@ -555,6 +664,18 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
return floatFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression {
floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
floatWindowFunc := newFloatWindowExpression(floatFunc)
floatFunc.floatInterfaceImpl.parent = floatWindowFunc
floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc
return floatWindowFunc
}
type integerFunc struct {
funcExpressionImpl
integerInterfaceImpl
@ -569,6 +690,18 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
return floatFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression {
integerFunc := &integerFunc{}
integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc)
intWindowFunc := newIntegerWindowExpression(integerFunc)
integerFunc.integerInterfaceImpl.parent = intWindowFunc
integerFunc.expressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc
}
type stringFunc struct {
funcExpressionImpl
stringInterfaceImpl

View file

@ -32,8 +32,8 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl
return &exp
}
// ConstLiteral is injected directly to SQL query, and does not appear in argument list.
func ConstLiteral(value interface{}) *literalExpressionImpl {
// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list.
func FixedLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value)
exp.constant = true

View file

@ -0,0 +1,146 @@
package jet
type commonWindowImpl struct {
expression Expression
window Window
}
func (w *commonWindowImpl) over(window ...Window) {
if len(window) > 0 {
w.window = window[0]
} else {
w.window = newWindowImpl(nil)
}
}
func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
w.expression.serialize(statement, out)
if w.window != nil {
out.WriteString("OVER")
w.window.serialize(statement, out)
}
}
// --------------------------------------
type windowExpression interface {
Expression
OVER(window ...Window) Expression
}
func newWindowExpression(Exp Expression) windowExpression {
newExp := &windowExpressionImpl{
Expression: Exp,
}
newExp.commonWindowImpl.expression = Exp
return newExp
}
type windowExpressionImpl struct {
Expression
commonWindowImpl
}
func (f *windowExpressionImpl) OVER(window ...Window) Expression {
f.commonWindowImpl.over(window...)
return f
}
func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// -----------------------------------------------------
type floatWindowExpression interface {
FloatExpression
OVER(window ...Window) FloatExpression
}
func newFloatWindowExpression(floatExp FloatExpression) floatWindowExpression {
newExp := &floatWindowExpressionImpl{
FloatExpression: floatExp,
}
newExp.commonWindowImpl.expression = floatExp
return newExp
}
type floatWindowExpressionImpl struct {
FloatExpression
commonWindowImpl
}
func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// ------------------------------------------------
type integerWindowExpression interface {
IntegerExpression
OVER(window ...Window) IntegerExpression
}
func newIntegerWindowExpression(intExp IntegerExpression) integerWindowExpression {
newExp := &integerWindowExpressionImpl{
IntegerExpression: intExp,
}
newExp.commonWindowImpl.expression = intExp
return newExp
}
type integerWindowExpressionImpl struct {
IntegerExpression
commonWindowImpl
}
func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// ------------------------------------------------
type boolWindowExpression interface {
BoolExpression
OVER(window ...Window) BoolExpression
}
func newBoolWindowExpression(boolExp BoolExpression) boolWindowExpression {
newExp := &boolWindowExpressionImpl{
BoolExpression: boolExp,
}
newExp.commonWindowImpl.expression = boolExp
return newExp
}
type boolWindowExpressionImpl struct {
BoolExpression
commonWindowImpl
}
func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}

186
internal/jet/window_func.go Normal file
View file

@ -0,0 +1,186 @@
package jet
// Window interface
type Window interface {
Serializer
ORDER_BY(expr ...OrderByClause) Window
ROWS(start FrameExtent, end ...FrameExtent) Window
RANGE(start FrameExtent, end ...FrameExtent) Window
GROUPS(start FrameExtent, end ...FrameExtent) Window
}
type windowImpl struct {
partitionBy []Expression
orderBy ClauseOrderBy
frameUnits string
start, end FrameExtent
parent Window
}
func newWindowImpl(parent Window) *windowImpl {
newWindow := &windowImpl{}
if parent == nil {
newWindow.parent = newWindow
} else {
newWindow.parent = parent
}
return newWindow
}
func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
out.WriteByte('(')
}
if w.partitionBy != nil {
out.WriteString("PARTITION BY")
serializeExpressionList(statement, w.partitionBy, ", ", out)
}
w.orderBy.SkipNewLine = true
w.orderBy.Serialize(statement, out)
if w.frameUnits != "" {
out.WriteString(w.frameUnits)
if w.end == nil {
w.start.serialize(statement, out)
} else {
out.WriteString("BETWEEN")
w.start.serialize(statement, out)
out.WriteString("AND")
w.end.serialize(statement, out)
}
}
if !contains(options, noWrap) {
out.WriteByte(')')
}
}
func (w *windowImpl) ORDER_BY(exprs ...OrderByClause) Window {
w.orderBy.List = exprs
return w.parent
}
func (w *windowImpl) ROWS(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "ROWS"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) RANGE(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "RANGE"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) GROUPS(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "GROUPS"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) setFrameRange(start FrameExtent, end ...FrameExtent) {
w.start = start
if len(end) > 0 {
w.end = end[0]
}
}
// PARTITION_BY window function constructor
func PARTITION_BY(exp Expression, exprs ...Expression) Window {
funImpl := newWindowImpl(nil)
funImpl.partitionBy = append([]Expression{exp}, exprs...)
return funImpl
}
// ORDER_BY window function constructor
func ORDER_BY(expr ...OrderByClause) Window {
funImpl := newWindowImpl(nil)
funImpl.orderBy.List = expr
return funImpl
}
// -----------------------------------------------
// FrameExtent interface
type FrameExtent interface {
Serializer
isFrameExtent()
}
// PRECEDING window frame clause
func PRECEDING(offset Serializer) FrameExtent {
return &frameExtentImpl{
preceding: true,
offset: offset,
}
}
// FOLLOWING window frame clause
func FOLLOWING(offset Serializer) FrameExtent {
return &frameExtentImpl{
preceding: false,
offset: offset,
}
}
type frameExtentImpl struct {
preceding bool
offset Serializer
}
func (f *frameExtentImpl) isFrameExtent() {}
func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if f == nil {
return
}
f.offset.serialize(statement, out)
if f.preceding {
out.WriteString("PRECEDING")
} else {
out.WriteString("FOLLOWING")
}
}
// -----------------------------------------------
// Window function keywords
var (
UNBOUNDED = keywordClause("UNBOUNDED")
CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"}
)
type frameExtentKeyword struct {
keywordClause
}
func (f frameExtentKeyword) isFrameExtent() {}
// -----------------------------------------------
// WindowName is used to specify window reference from WINDOW clause
func WindowName(name string) Window {
newWindow := &windowName{name: name}
newWindow.parent = newWindow
return newWindow
}
type windowName struct {
windowImpl
name string
}
func (w windowName) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.WriteString(w.name)
w.windowImpl.serialize(statement, out, noWrap)
out.WriteByte(')')
}

View file

@ -0,0 +1,21 @@
package jet
import "testing"
func TestFrameExtent(t *testing.T) {
assertClauseSerialize(t, PRECEDING(Int(2)), "$1 PRECEDING", int64(2))
assertClauseSerialize(t, FOLLOWING(Int(4)), "$1 FOLLOWING", int64(4))
}
func TestWindowFunctions(t *testing.T) {
assertClauseSerialize(t, PARTITION_BY(table1Col1), "(PARTITION BY table1.col1)")
assertClauseSerialize(t, PARTITION_BY(table1Col3).ORDER_BY(table1Col1), "(PARTITION BY table1.col3 ORDER BY table1.col1)")
assertClauseSerialize(t, ORDER_BY(table1Col1), "(ORDER BY table1.col1)")
assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1))), "(ORDER BY table1.col1 ROWS $1 PRECEDING)", int64(1))
assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1)), FOLLOWING(Int(33))),
"(ORDER BY table1.col1 ROWS BETWEEN $1 PRECEDING AND $2 FOLLOWING)", int64(1), int64(33))
assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
"(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)")
assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), CURRENT_ROW),
"(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)")
}