diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index 2509cee..25e0646 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -72,6 +72,34 @@ func newIntegerFunc(name string, expressions ...expression) IntegerExpression { return floatFunc } +type stringFunc struct { + funcExpressionImpl + stringInterfaceImpl +} + +func newStringFunc(name string, expressions ...expression) StringExpression { + stringFunc := &stringFunc{} + + stringFunc.funcExpressionImpl = *newFunc(name, expressions, stringFunc) + stringFunc.stringInterfaceImpl.parent = stringFunc + + return stringFunc +} + +type boolFunc struct { + funcExpressionImpl + boolInterfaceImpl +} + +func newBoolFunc(name string, expressions ...expression) BoolExpression { + boolFunc := &boolFunc{} + + boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + boolFunc.boolInterfaceImpl.parent = boolFunc + + return boolFunc +} + // ------------------ Mathematical functions ---------------// func ABSf(floatExpression FloatExpression) FloatExpression { @@ -157,3 +185,154 @@ func COUNTf(floatExpression FloatExpression) FloatExpression { func COUNTi(integerExpression IntegerExpression) IntegerExpression { return newIntegerFunc("COUNT", integerExpression) } + +//------------ String functions ------------------// + +func BIT_LENGTH(stringExpression StringExpression) IntegerExpression { + return newIntegerFunc("BIT_LENGTH", stringExpression) +} + +func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression { + return newIntegerFunc("CHAR_LENGTH", stringExpression) +} + +func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression { + return newIntegerFunc("OCTET_LENGTH", stringExpression) +} + +func LOWER(stringExpression StringExpression) StringExpression { + return newStringFunc("LOWER", stringExpression) +} + +func UPPER(stringExpression StringExpression) StringExpression { + return newStringFunc("UPPER", stringExpression) +} + +func BTRIM(stringExpression StringExpression) StringExpression { + return newStringFunc("BTRIM", stringExpression) +} + +func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { + if len(trimChars) > 0 { + return newStringFunc("LTRIM", str, trimChars[0]) + } + return newStringFunc("LTRIM", str) +} + +func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { + if len(trimChars) > 0 { + return newStringFunc("RTRIM", str, trimChars[0]) + } + return newStringFunc("RTRIM", str) +} + +func CHR(integerExpression IntegerExpression) StringExpression { + return newStringFunc("CHR", integerExpression) +} + +//func CONCAT(expressions ...expression) StringExpression { +// return newStringFunc("CONCAT", expressions...) +//} +// +//func CONCAT_WS(expressions ...expression) StringExpression { +// return newStringFunc("CONCAT_WS", expressions...) +//} + +func CONVERT(str StringExpression, fromEncoding StringExpression, toEncoding StringExpression) StringExpression { + return newStringFunc("CONVERT", str, fromEncoding, toEncoding) +} + +func CONVERT_FROM(str StringExpression, fromEncoding StringExpression) StringExpression { + return newStringFunc("CONVERT_FROM", str, fromEncoding) +} + +func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { + return newStringFunc("CONVERT_TO", str, toEncoding) +} + +func ENCODE(data StringExpression, format StringExpression) StringExpression { + return newStringFunc("ENCODE", data, format) +} + +func DECODE(data StringExpression, format StringExpression) StringExpression { + return newStringFunc("DECODE", data, format) +} + +//func FORMAT(formatStr StringExpression, formatArgs ...expression) StringExpression { +// args := []expression{formatStr} +// args = append(args, formatArgs...) +// return newStringFunc("FORMAT", args...) +//} + +func INITCAP(str StringExpression) StringExpression { + return newStringFunc("INITCAP", str) +} + +func LEFT(str StringExpression, n IntegerExpression) StringExpression { + return newStringFunc("LEFT", str, n) +} + +func RIGHT(str StringExpression, n IntegerExpression) StringExpression { + return newStringFunc("RIGHT", str, n) +} + +func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { + if len(encoding) > 0 { + return newStringFunc("LENGTH", str, encoding[0]) + } + return newStringFunc("LENGTH", str) +} + +func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { + if len(text) > 0 { + return newStringFunc("LPAD", str, length, text[0]) + } + + return newStringFunc("LPAD", str, length) +} + +func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { + if len(text) > 0 { + return newStringFunc("RPAD", str, length, text[0]) + } + + return newStringFunc("RPAD", str, length) +} + +func MD5(stringExpression StringExpression) StringExpression { + return newStringFunc("MD5", stringExpression) +} + +func REPEAT(str StringExpression, n IntegerExpression) StringExpression { + return newStringFunc("REPEAT", str, n) +} + +func REPLACE(text, from, to StringExpression) StringExpression { + return newStringFunc("REPLACE", text, from, to) +} + +func REVERSE(stringExpression StringExpression) StringExpression { + return newStringFunc("REVERSE", stringExpression) +} + +func STRPOS(str, substring StringExpression) IntegerExpression { + return newIntegerFunc("STRPOS", str, substring) +} + +func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { + if len(count) > 0 { + return newStringFunc("SUBSTR", str, from, count[0]) + } + return newStringFunc("SUBSTR", str, from) +} + +func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression { + if len(encoding) > 0 { + return newStringFunc("TO_ASCII", str, encoding[0]) + } + return newStringFunc("TO_ASCII", str) +} + +func TO_HEX(number IntegerExpression) StringExpression { + return newStringFunc("TO_HEX", number) +} diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 1217b54..cd462a9 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -73,7 +73,7 @@ type stringLiteral struct { literalExpression } -func String(value string) stringExpression { +func String(value string) StringExpression { stringLiteral := stringLiteral{} stringLiteral.literalExpression = *Literal(value) diff --git a/sqlbuilder/string_expression.go b/sqlbuilder/string_expression.go index 59e023a..ad6aee8 100644 --- a/sqlbuilder/string_expression.go +++ b/sqlbuilder/string_expression.go @@ -1,51 +1,75 @@ package sqlbuilder -type stringExpression interface { +type StringExpression interface { expression - EQ(rhs stringExpression) BoolExpression - NOT_EQ(rhs stringExpression) BoolExpression - IS_DISTINCT_FROM(rhs stringExpression) BoolExpression - IS_NOT_DISTINCT_FROM(rhs stringExpression) BoolExpression + EQ(rhs StringExpression) BoolExpression + NOT_EQ(rhs StringExpression) BoolExpression + IS_DISTINCT_FROM(rhs StringExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs StringExpression) BoolExpression - LT(rhs stringExpression) BoolExpression - LT_EQ(rhs stringExpression) BoolExpression - GT(rhs stringExpression) BoolExpression - GT_EQ(rhs stringExpression) BoolExpression + LT(rhs StringExpression) BoolExpression + LT_EQ(rhs StringExpression) BoolExpression + GT(rhs StringExpression) BoolExpression + GT_EQ(rhs StringExpression) BoolExpression + + CONCAT(rhs expression) StringExpression } type stringInterfaceImpl struct { - parent stringExpression + parent StringExpression } -func (s *stringInterfaceImpl) EQ(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression { return EQ(s.parent, rhs) } -func (s *stringInterfaceImpl) NOT_EQ(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) NOT_EQ(rhs StringExpression) BoolExpression { return NOT_EQ(s.parent, rhs) } -func (s *stringInterfaceImpl) IS_DISTINCT_FROM(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) IS_DISTINCT_FROM(rhs StringExpression) BoolExpression { return IS_DISTINCT_FROM(s.parent, rhs) } -func (s *stringInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs StringExpression) BoolExpression { return IS_NOT_DISTINCT_FROM(s.parent, rhs) } -func (s *stringInterfaceImpl) GT(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) GT(rhs StringExpression) BoolExpression { return GT(s.parent, rhs) } -func (s *stringInterfaceImpl) GT_EQ(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) GT_EQ(rhs StringExpression) BoolExpression { return GT_EQ(s.parent, rhs) } -func (s *stringInterfaceImpl) LT(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) LT(rhs StringExpression) BoolExpression { return LT(s.parent, rhs) } -func (s *stringInterfaceImpl) LT_EQ(rhs stringExpression) BoolExpression { +func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression { return LT_EQ(s.parent, rhs) } + +func (s *stringInterfaceImpl) CONCAT(rhs expression) StringExpression { + return newBinaryStringExpression(s.parent, rhs, "||") +} + +//---------------------------------------------------// +type binaryStringExpression struct { + expressionInterfaceImpl + stringInterfaceImpl + + binaryOpExpression +} + +func newBinaryStringExpression(lhs, rhs expression, operator string) StringExpression { + boolExpression := binaryStringExpression{} + + boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) + boolExpression.expressionInterfaceImpl.parent = &boolExpression + boolExpression.stringInterfaceImpl.parent = &boolExpression + + return &boolExpression +} diff --git a/sqlbuilder/string_expression_test.go b/sqlbuilder/string_expression_test.go index 9829459..57b1eca 100644 --- a/sqlbuilder/string_expression_test.go +++ b/sqlbuilder/string_expression_test.go @@ -40,3 +40,8 @@ func TestStringLT_EQ(t *testing.T) { assertExpressionSerialize(t, exp, "(table3.col2 <= table2.colStr)") assertExpressionSerialize(t, table3StrCol.LT_EQ(String("JOHN")), "(table3.col2 <= $1)", "JOHN") } + +func TestStringCONCAT(t *testing.T) { + assertExpressionSerialize(t, table3StrCol.CONCAT(table2ColStr), "(table3.col2 || table2.colStr)") + assertExpressionSerialize(t, table3StrCol.CONCAT(String("JOHN")), "(table3.col2 || $1)", "JOHN") +} diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 11fd42b..2ac6d21 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -113,6 +113,16 @@ func serializeColumnList(statement statementType, columns []column, out *queryDa return nil } +//func stringExpressionListToExpressionList(stringExpressions []StringExpression) []expression{ +// var ret []expression +// +// for _, strExp := range stringExpressions { +// ret = append(ret, strExp) +// } +// +// return ret +//} + func Query(statement Statement, db execution.Db, destination interface{}) error { query, args, err := statement.Sql() diff --git a/tests/types_test.go b/tests/types_test.go index cb71532..68cad95 100644 --- a/tests/types_test.go +++ b/tests/types_test.go @@ -39,8 +39,43 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.LT(String("Text")), AllTypes.Text.LT_EQ(AllTypes.CharacterVaryingPtr), AllTypes.Text.LT_EQ(String("Text")), + AllTypes.Text.CONCAT(String("text2")), + AllTypes.Text.CONCAT(Int(11)), + + BIT_LENGTH(AllTypes.Text), + CHAR_LENGTH(AllTypes.Character), + OCTET_LENGTH(AllTypes.Text), + LOWER(AllTypes.CharacterVaryingPtr), + UPPER(AllTypes.Character), + BTRIM(AllTypes.CharacterVarying), + LTRIM(AllTypes.CharacterVarying, String("A")), + RTRIM(AllTypes.CharacterVarying, String("B")), + CHR(Int(65)), + //CONCAT(String("string1"), Int(1), Float(11.12)), + //CONCAT_WS(String("string1"), Int(1), Float(11.12)), + CONVERT(String("text_in_utf8"), String("UTF8"), String("LATIN1")), + CONVERT_FROM(String("text_in_utf8"), String("UTF8")), + CONVERT_TO(String("text_in_utf8"), String("UTF8")), + ENCODE(String("123\000\001"), String("base64")), + DECODE(String("MTIzAAE="), String("base64")), + //FORMAT(String("Hello %s, %1$s"), String("World")), + INITCAP(String("hi THOMAS")), + LEFT(String("abcde"), Int(2)), + RIGHT(String("abcde"), Int(2)), + LENGTH(String("jose")), + LENGTH(String("jose"), String("UTF8")), + LPAD(String("Hi"), Int(5), String("xy")), + RPAD(String("Hi"), Int(5), String("xy")), + MD5(AllTypes.CharacterVarying), + REPEAT(AllTypes.Text, Int(33)), + REPLACE(AllTypes.Character, String("BA"), String("AB")), + REVERSE(AllTypes.CharacterVarying), + STRPOS(AllTypes.Text, String("A")), + SUBSTR(AllTypes.CharacterPtr, Int(3), Int(2)), + TO_HEX(AllTypes.IntegerPtr), ) + //fmt.Println(query.Sql()) fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{})