diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index 8c28200..9efbf9a 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -58,7 +58,7 @@ func TestBinaryExpression(t *testing.T) { err := alias.serializeForProjection(select_statement, &out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `($1 = $2) AS "alias_eq_expression"`) + assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`) }) t.Run("and", func(t *testing.T) { diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 09e3d6e..932c7f4 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -56,15 +56,15 @@ func (e *expressionInterfaceImpl) DESC() orderByClause { } func (e *expressionInterfaceImpl) serializeForGroupBy(statement statementType, out *queryData) error { - return e.parent.serialize(statement, out) + return e.parent.serialize(statement, out, NO_WRAP) } func (e *expressionInterfaceImpl) serializeForProjection(statement statementType, out *queryData) error { - return e.parent.serialize(statement, out) + return e.parent.serialize(statement, out, NO_WRAP) } func (e *expressionInterfaceImpl) serializeAsOrderBy(statement statementType, out *queryData) error { - return e.parent.serialize(statement, out) + return e.parent.serialize(statement, out, NO_WRAP) } // Representation of binary operations (e.g. comparisons, arithmetic) diff --git a/sqlbuilder/expression_old.go b/sqlbuilder/expression_old.go index b1bae4f..82f06e3 100644 --- a/sqlbuilder/expression_old.go +++ b/sqlbuilder/expression_old.go @@ -1,4 +1,4 @@ -// Query building functions for expression components +// Query building functions for expressions components package sqlbuilder import ( @@ -42,7 +42,7 @@ func (c *intervalExpression) serialize(statement statementType, out *queryData, } //// Interval returns a representation of duration -//func Interval(duration time.Duration) expression { +//func Interval(duration time.Duration) expressions { // intervalExp := &intervalExpression{ // duration: duration, // } diff --git a/sqlbuilder/float_expression.go b/sqlbuilder/float_expression.go index bb19674..4004da0 100644 --- a/sqlbuilder/float_expression.go +++ b/sqlbuilder/float_expression.go @@ -19,6 +19,8 @@ type FloatExpression interface { SUB(rhs FloatExpression) FloatExpression MUL(rhs FloatExpression) FloatExpression DIV(rhs FloatExpression) FloatExpression + MOD(rhs FloatExpression) FloatExpression + POW(rhs FloatExpression) FloatExpression } type floatInterfaceImpl struct { @@ -73,6 +75,14 @@ func (n *floatInterfaceImpl) DIV(expression FloatExpression) FloatExpression { return newBinaryFloatExpression(n.parent, expression, "/") } +func (n *floatInterfaceImpl) MOD(expression FloatExpression) FloatExpression { + return newBinaryFloatExpression(n.parent, expression, "%") +} + +func (n *floatInterfaceImpl) POW(expression FloatExpression) FloatExpression { + return newBinaryFloatExpression(n.parent, expression, "^") +} + //---------------------------------------------------// type binaryFloatExpression struct { expressionInterfaceImpl @@ -113,7 +123,7 @@ func newFloatExpressionWrap(expression expression) FloatExpression { func (n *floatExpressionWrapper) serialize(statement statementType, out *queryData, options ...serializeOption) error { if n == nil { - return errors.New("Float expression wrapper is nil. ") + return errors.New("Float expressions wrapper is nil. ") } //out.writeString("(") err := n.expression.serialize(statement, out) diff --git a/sqlbuilder/float_expression_test.go b/sqlbuilder/float_expression_test.go index 958c0b4..ae249f4 100644 --- a/sqlbuilder/float_expression_test.go +++ b/sqlbuilder/float_expression_test.go @@ -6,33 +6,53 @@ import ( ) func TestFloatExpressionEQColumn(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.EQ(table2Col3)), "(table1.col1 = table2.col3)") -} - -func TestFloatExpressionEQInt(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.EQ(Int(11))), "(table1.col1 = $1)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.EQ(table2ColFloat)), "(table1.colFloat = table2.colFloat)") } func TestFloatExpressionEQFloat(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.EQ(Int(22))), "(table1.col1 = $1)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.EQ(Float(11))), "(table1.colFloat = $1)") } func TestFloatExpressionNOT_EQ(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.NOT_EQ(table2Col3)), "(table1.col1 != table2.col3)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.NOT_EQ(table2ColFloat)), "(table1.colFloat != table2.colFloat)") } func TestFloatExpressionGT(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.GT(table2Col3)), "(table1.col1 > table2.col3)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.GT(table2ColFloat)), "(table1.colFloat > table2.colFloat)") } func TestFloatExpressionGT_EQ(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.GT_EQ(table2Col3)), "(table1.col1 >= table2.col3)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.GT_EQ(table2ColFloat)), "(table1.colFloat >= table2.colFloat)") } func TestFloatExpressionLT(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.LT(table2Col3)), "(table1.col1 < table2.col3)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.LT(table2ColFloat)), "(table1.colFloat < table2.colFloat)") } func TestFloatExpressionLT_EQ(t *testing.T) { - assert.Equal(t, getTestSerialize(t, table1Col1.LT_EQ(table2Col3)), "(table1.col1 <= table2.col3)") + assert.Equal(t, getTestSerialize(t, table1ColFloat.LT_EQ(table2ColFloat)), "(table1.colFloat <= table2.colFloat)") +} + +func TestFloatExpressionADD(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColFloat.ADD(table2ColFloat)), "(table1.colFloat + table2.colFloat)") +} + +func TestFloatExpressionSUB(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColFloat.SUB(table2ColFloat)), "(table1.colFloat - table2.colFloat)") +} + +func TestFloatExpressionMUL(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColFloat.MUL(table2ColFloat)), "(table1.colFloat * table2.colFloat)") +} + +func TestFloatExpressionDIV(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColFloat.DIV(table2ColFloat)), "(table1.colFloat / table2.colFloat)") +} + +func TestFloatExpressionMOD(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColFloat.MOD(table2ColFloat)), "(table1.colFloat % table2.colFloat)") +} + +func TestFloatExpressionEXP(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColFloat.POW(table2ColFloat)), "(table1.colFloat ^ table2.colFloat)") } diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index 5d30433..2509cee 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -5,8 +5,8 @@ import "errors" type funcExpressionImpl struct { expressionInterfaceImpl - name string - expression []expression + name string + expressions []expression } func ROW(expressions ...expression) expression { @@ -15,8 +15,8 @@ func ROW(expressions ...expression) expression { func newFunc(name string, expressions []expression, parent expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ - name: name, - expression: expressions, + name: name, + expressions: expressions, } if parent != nil { @@ -30,12 +30,12 @@ func newFunc(name string, expressions []expression, parent expression) *funcExpr func (f *funcExpressionImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { if f == nil { - return errors.New("Function expression is nil. ") + return errors.New("Function expressions is nil. ") } out.writeString(f.name + "(") - err := serializeExpressionList(statement, f.expression, ", ", out) + err := serializeExpressionList(statement, f.expressions, ", ", out) if err != nil { return err } @@ -44,8 +44,6 @@ func (f *funcExpressionImpl) serialize(statement statementType, out *queryData, return nil } -// ------------------- FLOAT FUNCTIONS --------------------------// - type floatFunc struct { funcExpressionImpl floatInterfaceImpl @@ -60,20 +58,6 @@ func newFloatFunc(name string, expressions ...expression) FloatExpression { return floatFunc } -func COUNTf(floatExpression FloatExpression) FloatExpression { - return newFloatFunc("COUNT", floatExpression) -} - -func MAXf(floatExpression FloatExpression) FloatExpression { - return newFloatFunc("MAX", floatExpression) -} - -func SUMf(floatExpression FloatExpression) FloatExpression { - return newFloatFunc("SUM", floatExpression) -} - -// ------------------- FLOAT FUNCTIONS --------------------------// - type integerFunc struct { funcExpressionImpl integerInterfaceImpl @@ -88,14 +72,88 @@ func newIntegerFunc(name string, expressions ...expression) IntegerExpression { return floatFunc } -func COUNTi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("COUNT", integerExpression) +// ------------------ Mathematical functions ---------------// + +func ABSf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("ABS", floatExpression) +} + +func ABSi(integerExpression IntegerExpression) FloatExpression { + return newFloatFunc("ABS", integerExpression) +} + +func SQRTf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("SQRT", floatExpression) +} + +func SQRTi(integerExpression IntegerExpression) FloatExpression { + return newFloatFunc("SQRT", integerExpression) +} + +func CBRTf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("CBRT", floatExpression) +} + +func CBRTi(integerExpression IntegerExpression) FloatExpression { + return newFloatFunc("CBRT", integerExpression) +} + +func CEIL(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("CEIL", floatExpression) +} + +func FLOOR(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("FLOOR", floatExpression) +} + +func ROUND(floatExpression FloatExpression, intExpression ...IntegerExpression) FloatExpression { + if len(intExpression) > 0 { + return newFloatFunc("ROUND", floatExpression, intExpression[0]) + } + return newFloatFunc("ROUND", floatExpression) +} + +func SIGN(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("SIGN", floatExpression) +} + +func TRUNC(floatExpression FloatExpression, intExpression ...IntegerExpression) FloatExpression { + if len(intExpression) > 0 { + return newFloatFunc("TRUNC", floatExpression, intExpression[0]) + } + return newFloatFunc("TRUNC", floatExpression) +} + +func LN(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("LN", floatExpression) +} + +func LOG(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("LOG", floatExpression) +} + +// ----------------- Group function operators -------------------// + +func MAXf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("MAX", floatExpression) } func MAXi(integerExpression IntegerExpression) IntegerExpression { return newIntegerFunc("MAX", integerExpression) } +func SUMf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("SUM", floatExpression) +} + func SUMi(integerExpression IntegerExpression) IntegerExpression { return newIntegerFunc("SUM", integerExpression) } + +func COUNTf(floatExpression FloatExpression) FloatExpression { + return newFloatFunc("COUNT", floatExpression) +} + +func COUNTi(integerExpression IntegerExpression) IntegerExpression { + return newIntegerFunc("COUNT", integerExpression) +} diff --git a/sqlbuilder/func_expression_test.go b/sqlbuilder/func_expression_test.go index 72861f1..0e92592 100644 --- a/sqlbuilder/func_expression_test.go +++ b/sqlbuilder/func_expression_test.go @@ -5,6 +5,119 @@ import ( "testing" ) +func TestFuncABS(t *testing.T) { + t.Run("float", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, ABSf(table1ColFloat)), "ABS(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, ABSf(Float(11.2222))), "ABS($1)") + }) + + t.Run("integer", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, ABSi(table1ColInt)), "ABS(table1.colInt)") + assert.Equal(t, getTestSerialize(t, ABSi(Int(11))), "ABS($1)") + }) +} + +func TestFuncSQRT(t *testing.T) { + t.Run("float", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, SQRTf(table1ColFloat)), "SQRT(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, SQRTf(Float(11.2222))), "SQRT($1)") + }) + + t.Run("integer", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, SQRTi(table1ColInt)), "SQRT(table1.colInt)") + assert.Equal(t, getTestSerialize(t, SQRTi(Int(11))), "SQRT($1)") + }) +} + +func TestFuncCBRT(t *testing.T) { + t.Run("float", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, CBRTf(table1ColFloat)), "CBRT(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, CBRTf(Float(11.2222))), "CBRT($1)") + }) + + t.Run("integer", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, CBRTi(table1ColInt)), "CBRT(table1.colInt)") + assert.Equal(t, getTestSerialize(t, CBRTi(Int(11))), "CBRT($1)") + }) +} + +func TestFuncMAX(t *testing.T) { + t.Run("float", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, MAXf(table1ColFloat)), "MAX(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, MAXf(Float(11.2222))), "MAX($1)") + }) + + t.Run("integer", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, MAXi(table1ColInt)), "MAX(table1.colInt)") + assert.Equal(t, getTestSerialize(t, MAXi(Int(11))), "MAX($1)") + }) +} + +func TestFuncSUM(t *testing.T) { + t.Run("float", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, SUMf(table1ColFloat)), "SUM(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, SUMf(Float(11.2222))), "SUM($1)") + }) + + t.Run("integer", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, SUMi(table1ColInt)), "SUM(table1.colInt)") + assert.Equal(t, getTestSerialize(t, SUMi(Int(11))), "SUM($1)") + }) +} + +func TestFuncCOUNT(t *testing.T) { + t.Run("float", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, COUNTf(table1ColFloat)), "COUNT(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, COUNTf(Float(11.2222))), "COUNT($1)") + }) + + t.Run("integer", func(t *testing.T) { + assert.Equal(t, getTestSerialize(t, COUNTi(table1ColInt)), "COUNT(table1.colInt)") + assert.Equal(t, getTestSerialize(t, COUNTi(Int(11))), "COUNT($1)") + }) +} + +func TestFuncCEIL(t *testing.T) { + assert.Equal(t, getTestSerialize(t, CEIL(table1ColFloat)), "CEIL(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, CEIL(Float(11.2222))), "CEIL($1)") +} + +func TestFuncFLOOR(t *testing.T) { + assert.Equal(t, getTestSerialize(t, FLOOR(table1ColFloat)), "FLOOR(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, FLOOR(Float(11.2222))), "FLOOR($1)") +} + +func TestFuncROUND(t *testing.T) { + assert.Equal(t, getTestSerialize(t, ROUND(table1ColFloat)), "ROUND(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, ROUND(Float(11.2222))), "ROUND($1)") + + assert.Equal(t, getTestSerialize(t, ROUND(table1ColFloat, Int(2))), "ROUND(table1.colFloat, $1)") + assert.Equal(t, getTestSerialize(t, ROUND(Float(11.2222), Int(1))), "ROUND($1, $2)") +} + +func TestFuncSIGN(t *testing.T) { + assert.Equal(t, getTestSerialize(t, SIGN(table1ColFloat)), "SIGN(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, SIGN(Float(11.2222))), "SIGN($1)") +} + +func TestFuncTRUNC(t *testing.T) { + assert.Equal(t, getTestSerialize(t, TRUNC(table1ColFloat)), "TRUNC(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, TRUNC(Float(11.2222))), "TRUNC($1)") + + assert.Equal(t, getTestSerialize(t, TRUNC(table1ColFloat, Int(2))), "TRUNC(table1.colFloat, $1)") + assert.Equal(t, getTestSerialize(t, TRUNC(Float(11.2222), Int(1))), "TRUNC($1, $2)") +} + +func TestFuncLN(t *testing.T) { + assert.Equal(t, getTestSerialize(t, LN(table1ColFloat)), "LN(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, LN(Float(11.2222))), "LN($1)") +} + +func TestFuncLOG(t *testing.T) { + assert.Equal(t, getTestSerialize(t, LOG(table1ColFloat)), "LOG(table1.colFloat)") + assert.Equal(t, getTestSerialize(t, LOG(Float(11.2222))), "LOG($1)") +} + func TestCase1(t *testing.T) { query := CASE(). WHEN(table3Col1.EQ(Int(1))).THEN(table3Col1.ADD(Int(1))). diff --git a/sqlbuilder/integer_expression.go b/sqlbuilder/integer_expression.go index c5e3b48..6598f29 100644 --- a/sqlbuilder/integer_expression.go +++ b/sqlbuilder/integer_expression.go @@ -17,11 +17,15 @@ type IntegerExpression interface { SUB(rhs IntegerExpression) IntegerExpression MUL(rhs IntegerExpression) IntegerExpression DIV(rhs IntegerExpression) IntegerExpression + MOD(rhs IntegerExpression) IntegerExpression + POW(rhs IntegerExpression) IntegerExpression - BitAnd(expression IntegerExpression) IntegerExpression - BitOr(expression IntegerExpression) IntegerExpression - BitXor(expression IntegerExpression) IntegerExpression - BitNot() IntegerExpression + BIT_AND(expression IntegerExpression) IntegerExpression + BIT_OR(expression IntegerExpression) IntegerExpression + BIT_XOR(expression IntegerExpression) IntegerExpression + BIT_NOT() IntegerExpression + BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression + BIT_SHIFT_RIGHT(intExpression IntegerExpression) IntegerExpression } type integerInterfaceImpl struct { @@ -61,35 +65,51 @@ func (i *integerInterfaceImpl) LT_EQ(expression IntegerExpression) BoolExpressio } func (i *integerInterfaceImpl) ADD(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "+") + return newBinaryIntegerExpression(i.parent, expression, "+") } func (i *integerInterfaceImpl) SUB(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "-") + return newBinaryIntegerExpression(i.parent, expression, "-") } func (i *integerInterfaceImpl) MUL(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "*") + return newBinaryIntegerExpression(i.parent, expression, "*") } func (i *integerInterfaceImpl) DIV(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "/") + return newBinaryIntegerExpression(i.parent, expression, "/") } -func (i *integerInterfaceImpl) BitAnd(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "&") +func (n *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(n.parent, expression, "%") } -func (i *integerInterfaceImpl) BitOr(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "|") +func (n *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(n.parent, expression, "^") } -func (i *integerInterfaceImpl) BitXor(expression IntegerExpression) IntegerExpression { - return NewBinaryIntegerExpression(i.parent, expression, "#") +func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(i.parent, expression, "&") } -func (i *integerInterfaceImpl) BitNot() IntegerExpression { - return NewPrefixIntegerOpExpression(i.parent, "~") +func (i *integerInterfaceImpl) BIT_OR(expression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(i.parent, expression, "|") +} + +func (i *integerInterfaceImpl) BIT_XOR(expression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(i.parent, expression, "#") +} + +func (i *integerInterfaceImpl) BIT_NOT() IntegerExpression { + return newPrefixIntegerOpExpression(i.parent, "~") +} + +func (i *integerInterfaceImpl) BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(i.parent, intExpression, "<<") +} + +func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) IntegerExpression { + return newBinaryIntegerExpression(i.parent, intExpression, ">>") } //---------------------------------------------------// @@ -100,7 +120,7 @@ type binaryIntegerExpression struct { binaryOpExpression } -func NewBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { +func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { integerExpression := binaryIntegerExpression{} integerExpression.expressionInterfaceImpl.parent = &integerExpression @@ -119,7 +139,7 @@ type prefixIntegerOpExpression struct { prefixOpExpression } -func NewPrefixIntegerOpExpression(expression IntegerExpression, operator string) IntegerExpression { +func newPrefixIntegerOpExpression(expression IntegerExpression, operator string) IntegerExpression { integerExpression := prefixIntegerOpExpression{} integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) diff --git a/sqlbuilder/integer_expression_test.go b/sqlbuilder/integer_expression_test.go new file mode 100644 index 0000000..7a759b7 --- /dev/null +++ b/sqlbuilder/integer_expression_test.go @@ -0,0 +1,68 @@ +package sqlbuilder + +import ( + "gotest.tools/assert" + "testing" +) + +func TestIntegerExpressionEQColumn(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.EQ(table2ColInt)), "(table1.colInt = table2.colInt)") +} + +func TestIntegerExpressionEQInt(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.EQ(Int(11))), "(table1.colInt = $1)") +} + +func TestIntegerExpressionNOT_EQ(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.NOT_EQ(table2ColInt)), "(table1.colInt != table2.colInt)") +} + +func TestIntegerExpressionGT(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.GT(table2ColInt)), "(table1.colInt > table2.colInt)") +} + +func TestIntegerExpressionGT_EQ(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.GT_EQ(table2ColInt)), "(table1.colInt >= table2.colInt)") +} + +func TestIntegerExpressionLT(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.LT(table2ColInt)), "(table1.colInt < table2.colInt)") +} + +func TestIntegerExpressionLT_EQ(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.LT_EQ(table2ColInt)), "(table1.colInt <= table2.colInt)") +} + +func TestIntegerExpressionADD(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.ADD(table2ColInt)), "(table1.colInt + table2.colInt)") +} + +func TestIntegerExpressionSUB(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.SUB(table2ColInt)), "(table1.colInt - table2.colInt)") +} + +func TestIntegerExpressionMUL(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.MUL(table2ColInt)), "(table1.colInt * table2.colInt)") +} + +func TestIntegerExpressionDIV(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.DIV(table2ColInt)), "(table1.colInt / table2.colInt)") +} + +func TestIntExpressionMOD(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.MOD(table2ColInt)), "(table1.colInt % table2.colInt)") +} + +func TestIntExpressionEXP(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.POW(table2ColInt)), "(table1.colInt ^ table2.colInt)") +} + +func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.BIT_SHIFT_LEFT(table2ColInt)), "(table1.colInt << table2.colInt)") + assert.Equal(t, getTestSerialize(t, table1ColInt.BIT_SHIFT_LEFT(Int(2))), "(table1.colInt << $1)") +} + +func TestIntExpressionBIT_SHIFT_RIGHT(t *testing.T) { + assert.Equal(t, getTestSerialize(t, table1ColInt.BIT_SHIFT_RIGHT(table2ColInt)), "(table1.colInt >> table2.colInt)") + assert.Equal(t, getTestSerialize(t, table1ColInt.BIT_SHIFT_RIGHT(Int(11))), "(table1.colInt >> $1)") +} diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 2121600..1217b54 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -26,7 +26,7 @@ type integerLiteralExpression struct { integerInterfaceImpl } -func Int(value int) IntegerExpression { +func Int(value int64) IntegerExpression { numLiteral := &integerLiteralExpression{} numLiteral.literalExpression = *Literal(value) diff --git a/sqlbuilder/string_expression_test.go b/sqlbuilder/string_expression_test.go index 7631d4d..8a3c8f2 100644 --- a/sqlbuilder/string_expression_test.go +++ b/sqlbuilder/string_expression_test.go @@ -6,7 +6,7 @@ import ( ) func TestStringEQColumn(t *testing.T) { - exp := table3StrCol.EQ(table2StrCol) + exp := table3StrCol.EQ(table2ColStr) out := queryData{} err := exp.serialize(select_statement, &out) @@ -26,7 +26,7 @@ func TestStringEQString(t *testing.T) { } func TestStringNOT_EQ(t *testing.T) { - exp := table3StrCol.NOT_EQ(table2StrCol) + exp := table3StrCol.NOT_EQ(table2ColStr) out := queryData{} err := exp.serialize(select_statement, &out) @@ -36,7 +36,7 @@ func TestStringNOT_EQ(t *testing.T) { } func TestStringGT(t *testing.T) { - exp := table3StrCol.GT(table2StrCol) + exp := table3StrCol.GT(table2ColStr) out := queryData{} err := exp.serialize(select_statement, &out) @@ -46,7 +46,7 @@ func TestStringGT(t *testing.T) { } func TestStringGT_EQ(t *testing.T) { - exp := table3StrCol.GT_EQ(table2StrCol) + exp := table3StrCol.GT_EQ(table2ColStr) out := queryData{} err := exp.serialize(select_statement, &out) @@ -56,7 +56,7 @@ func TestStringGT_EQ(t *testing.T) { } func TestStringLT(t *testing.T) { - exp := table3StrCol.LT(table2StrCol) + exp := table3StrCol.LT(table2ColStr) out := queryData{} err := exp.serialize(select_statement, &out) @@ -66,7 +66,7 @@ func TestStringLT(t *testing.T) { } func TestStringLT_EQ(t *testing.T) { - exp := table3StrCol.LT_EQ(table2StrCol) + exp := table3StrCol.LT_EQ(table2ColStr) out := queryData{} err := exp.serialize(select_statement, &out) diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go index 987848b..32204f4 100644 --- a/sqlbuilder/test_utils.go +++ b/sqlbuilder/test_utils.go @@ -6,6 +6,7 @@ import ( ) var table1Col1 = NewIntegerColumn("col1", Nullable) +var table1ColInt = NewIntegerColumn("colInt", Nullable) var table1ColFloat = NewFloatColumn("colFloat", Nullable) var table1Col3 = NewIntegerColumn("col3", Nullable) var table1ColTime = NewTimeColumn("colTime", Nullable) @@ -15,6 +16,7 @@ var table1 = NewTable( "db", "table1", table1Col1, + table1ColInt, table1ColFloat, table1Col3, table1ColTime, @@ -22,7 +24,9 @@ var table1 = NewTable( var table2Col3 = NewIntegerColumn("col3", Nullable) var table2Col4 = NewIntegerColumn("col4", Nullable) -var table2StrCol = NewStringColumn("colStr", Nullable) +var table2ColInt = NewIntegerColumn("colInt", Nullable) +var table2ColFloat = NewFloatColumn("colFloat", Nullable) +var table2ColStr = NewStringColumn("colStr", Nullable) var table2ColBool = NewBoolColumn("colBool", Nullable) var table2ColTime = NewTimeColumn("colTime", Nullable) @@ -31,7 +35,9 @@ var table2 = NewTable( "table2", table2Col3, table2Col4, - table2StrCol, + table2ColInt, + table2ColFloat, + table2ColStr, table2ColBool, table2ColTime) diff --git a/tests/select_test.go b/tests/select_test.go index 621ef58..a4387a6 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -26,7 +26,7 @@ WHERE actor.actor_id = 1; SELECT(Actor.AllColumns). WHERE(Actor.ActorID.EQ(Int(1))) - assertQuery(t, query, expectedSql, 1) + assertQuery(t, query, expectedSql, int64(1)) actor := model.Actor{} err := query.Query(db, &actor) @@ -1077,7 +1077,7 @@ LIMIT 20; ORDER_BY(Payment.PaymentID.ASC()). LIMIT(20) - assertQuery(t, query, expectedQuery, 1, "ONE", 2, "TWO", 3, "THREE", "OTHER", int64(20)) + assertQuery(t, query, expectedQuery, int64(1), "ONE", int64(2), "TWO", int64(3), "THREE", "OTHER", int64(20)) dest := []struct { StaffIdNum string diff --git a/tests/types_test.go b/tests/types_test.go index 73887fe..cb71532 100644 --- a/tests/types_test.go +++ b/tests/types_test.go @@ -88,14 +88,11 @@ func TestBoolOperators(t *testing.T) { assert.NilError(t, err) } -func TestNumericOperators(t *testing.T) { +func TestFloatOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Numeric.EQ(AllTypes.Numeric), AllTypes.Decimal.EQ(Float(12)), AllTypes.Real.EQ(Float(12.12)), - //AllTypes.Smallint.NOT_EQ(AllTypes.Real), - AllTypes.Integer.NOT_EQ(Int(12)), - AllTypes.Bigint.NOT_EQ(Int(12)), AllTypes.Numeric.IS_DISTINCT_FROM(AllTypes.Numeric), AllTypes.Decimal.IS_DISTINCT_FROM(Float(12)), AllTypes.Real.IS_DISTINCT_FROM(Float(12.12)), @@ -105,15 +102,80 @@ func TestNumericOperators(t *testing.T) { //AllTypes.Numeric.LT(AllTypes.Integer), AllTypes.Numeric.LT(Float(124)), AllTypes.Numeric.LT(Float(34.56)), - //AllTypes.Smallint.LT_EQ(AllTypes.Numeric), - AllTypes.Integer.LT_EQ(Int(45)), - AllTypes.Bigint.LT_EQ(Int(65)), //AllTypes.Numeric.GT(AllTypes.Smallint), AllTypes.Numeric.GT(Float(124)), AllTypes.Numeric.GT(Float(34.56)), + + AllTypes.Real.ADD(AllTypes.RealPtr), + AllTypes.Real.ADD(Float(11.22)), + AllTypes.Real.SUB(AllTypes.RealPtr), + AllTypes.Real.SUB(Float(11.22)), + AllTypes.Real.MUL(AllTypes.RealPtr), + AllTypes.Real.MUL(Float(11.22)), + AllTypes.Real.DIV(AllTypes.RealPtr), + AllTypes.Real.DIV(Float(11.22)), + AllTypes.Decimal.MOD(AllTypes.Decimal), + AllTypes.Decimal.MOD(Float(11.22)), + AllTypes.Real.POW(AllTypes.RealPtr), + AllTypes.Real.POW(Float(11.22)), + + ABSf(AllTypes.Real), + SQRTf(AllTypes.Real), + CBRTf(AllTypes.Real), + CEIL(AllTypes.Real), + FLOOR(AllTypes.Real), + ROUND(AllTypes.Decimal), + ROUND(AllTypes.Decimal, Int(3)).AS("round"), + SIGN(AllTypes.Real), + TRUNC(AllTypes.Decimal), + TRUNC(AllTypes.Decimal, Int(1)), + ) + + fmt.Println(query.DebugSql()) + + err := query.Query(db, &struct{}{}) + + assert.NilError(t, err) +} + +func TestIntegerOperators(t *testing.T) { + query := AllTypes.SELECT( + AllTypes.Integer.EQ(AllTypes.IntegerPtr), + AllTypes.Bigint.EQ(Int(12)), + //AllTypes.Smallint.NOT_EQ(AllTypes.Real), + AllTypes.Integer.NOT_EQ(AllTypes.IntegerPtr), + AllTypes.Bigint.NOT_EQ(Int(12)), + AllTypes.Integer.LT(AllTypes.IntegerPtr), + AllTypes.Bigint.LT(Int(65)), + //AllTypes.Smallint.LT_EQ(AllTypes.Numeric), + AllTypes.Integer.LT_EQ(AllTypes.IntegerPtr), + AllTypes.Bigint.LT_EQ(Int(65)), //AllTypes.Smallint.GT_EQ(AllTypes.Numeric), - AllTypes.Integer.GT_EQ(Int(45)), + AllTypes.Integer.GT(AllTypes.IntegerPtr), + AllTypes.Bigint.GT(Int(65)), + AllTypes.Integer.GT_EQ(AllTypes.IntegerPtr), AllTypes.Bigint.GT_EQ(Int(65)), + + AllTypes.Integer.ADD(AllTypes.Integer), + AllTypes.Integer.ADD(Int(11)), + AllTypes.Integer.SUB(AllTypes.Integer), + AllTypes.Integer.SUB(Int(11)), + AllTypes.Integer.MUL(AllTypes.Integer), + AllTypes.Integer.MUL(Int(11)), + AllTypes.Integer.DIV(AllTypes.Integer), + AllTypes.Integer.DIV(Int(11)), + AllTypes.Integer.MOD(AllTypes.Integer), + AllTypes.Integer.MOD(Int(11)), + AllTypes.Integer.POW(AllTypes.Smallint), + AllTypes.Integer.POW(Int(11)), + AllTypes.Integer.BIT_SHIFT_LEFT(AllTypes.Smallint), + AllTypes.Integer.BIT_SHIFT_LEFT(Int(11)), + AllTypes.Integer.BIT_SHIFT_RIGHT(AllTypes.Smallint), + AllTypes.Integer.BIT_SHIFT_RIGHT(Int(11)), + + ABSi(AllTypes.Integer), + SQRTi(AllTypes.Integer), + CBRTi(AllTypes.Integer), ) fmt.Println(query.DebugSql())