Add support for postgres arrays

This commit is contained in:
Arjen Brouwer 2024-09-03 15:39:36 +02:00 committed by go-jet
parent b835e25665
commit d3ada5361e
27 changed files with 558 additions and 74 deletions

View file

@ -0,0 +1,93 @@
package jet
// ArrayExpression interface
type ArrayExpression[E Expression] interface {
Expression
EQ(rhs ArrayExpression[E]) BoolExpression
NOT_EQ(rhs ArrayExpression[E]) BoolExpression
LT(rhs ArrayExpression[E]) BoolExpression
GT(rhs ArrayExpression[E]) BoolExpression
LT_EQ(rhs ArrayExpression[E]) BoolExpression
GT_EQ(rhs ArrayExpression[E]) BoolExpression
CONTAINS(rhs ArrayExpression[E]) BoolExpression
IS_CONTAINED_BY(rhs ArrayExpression[E]) BoolExpression
OVERLAP(rhs ArrayExpression[E]) BoolExpression
CONCAT(rhs ArrayExpression[E]) ArrayExpression[E]
CONCAT_ELEMENT(E) ArrayExpression[E]
AT(expression IntegerExpression) Expression
}
type arrayInterfaceImpl[E Expression] struct {
parent ArrayExpression[E]
}
type BinaryBoolOp func(Expression, Expression) BoolExpression
func (a arrayInterfaceImpl[E]) EQ(rhs ArrayExpression[E]) BoolExpression {
return Eq(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) NOT_EQ(rhs ArrayExpression[E]) BoolExpression {
return NotEq(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) LT(rhs ArrayExpression[E]) BoolExpression {
return Lt(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) GT(rhs ArrayExpression[E]) BoolExpression {
return Gt(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) LT_EQ(rhs ArrayExpression[E]) BoolExpression {
return LtEq(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) GT_EQ(rhs ArrayExpression[E]) BoolExpression {
return GtEq(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) CONTAINS(rhs ArrayExpression[E]) BoolExpression {
return Contains(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) IS_CONTAINED_BY(rhs ArrayExpression[E]) BoolExpression {
return IsContainedBy(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) OVERLAP(rhs ArrayExpression[E]) BoolExpression {
return Overlap(a.parent, rhs)
}
func (a arrayInterfaceImpl[E]) CONCAT(rhs ArrayExpression[E]) ArrayExpression[E] {
return ArrayExp[E](NewBinaryOperatorExpression(a.parent, rhs, "||"))
}
func (a arrayInterfaceImpl[E]) CONCAT_ELEMENT(rhs E) ArrayExpression[E] {
return ArrayExp[E](NewBinaryOperatorExpression(a.parent, rhs, "||"))
}
func (a arrayInterfaceImpl[E]) AT(expression IntegerExpression) Expression {
return arraySubscriptExpr(a.parent, expression)
}
type arrayExpressionWrapper[E Expression] struct {
arrayInterfaceImpl[E]
Expression
}
func newArrayExpressionWrap[E Expression](expression Expression) ArrayExpression[E] {
arrayExpressionWrapper := arrayExpressionWrapper[E]{Expression: expression}
arrayExpressionWrapper.arrayInterfaceImpl.parent = &arrayExpressionWrapper
return &arrayExpressionWrapper
}
// ArrayExp is array expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as array expression.
// Does not add sql cast to generated sql builder output.
func ArrayExp[E Expression](expression Expression) ArrayExpression[E] {
return newArrayExpressionWrap[E](expression)
}

View file

@ -0,0 +1,59 @@
package jet
import (
"github.com/lib/pq"
"testing"
)
func TestArrayExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.EQ(table2ColArray), "(table1.col_array_string = table2.col_array_string)")
}
func TestArrayExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.NOT_EQ(table2ColArray), "(table1.col_array_string != table2.col_array_string)")
assertClauseSerialize(t, table1ColStringArray.NOT_EQ(StringArray([]string{"x"})), "(table1.col_array_string != $1)", pq.StringArray{"x"})
}
func TestArrayExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.LT(table2ColArray), "(table1.col_array_string < table2.col_array_string)")
}
func TestArrayExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.GT(table2ColArray), "(table1.col_array_string > table2.col_array_string)")
}
func TestArrayExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.LT_EQ(table2ColArray), "(table1.col_array_string <= table2.col_array_string)")
}
func TestArrayExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.GT_EQ(table2ColArray), "(table1.col_array_string >= table2.col_array_string)")
}
func TestArrayExpressionCONTAINS(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.CONTAINS(table2ColArray), "(table1.col_array_string @> table2.col_array_string)")
assertClauseSerialize(t, table1ColStringArray.CONTAINS(StringArray([]string{"x"})), "(table1.col_array_string @> $1)", pq.StringArray{"x"})
}
func TestArrayExpressionCONTAINED_BY(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.IS_CONTAINED_BY(table2ColArray), "(table1.col_array_string <@ table2.col_array_string)")
assertClauseSerialize(t, table1ColStringArray.IS_CONTAINED_BY(StringArray([]string{"x"})), "(table1.col_array_string <@ $1)", pq.StringArray{"x"})
}
func TestArrayExpressionOVERLAP(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.OVERLAP(table2ColArray), "(table1.col_array_string && table2.col_array_string)")
}
func TestArrayExpressionCONCAT(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.CONCAT(table2ColArray), "(table1.col_array_string || table2.col_array_string)")
assertClauseSerialize(t, table1ColStringArray.CONCAT(StringArray([]string{"x"})), "(table1.col_array_string || $1)", pq.StringArray{"x"})
}
func TestArrayExpressionCONCAT_ELEMENT(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.CONCAT_ELEMENT(StringExp(table2ColArray.AT(Int(1)))), "(table1.col_array_string || (table2.col_array_string[$1]))", int64(1))
assertClauseSerialize(t, table1ColStringArray.CONCAT_ELEMENT(String("x")), "(table1.col_array_string || $1)", "x")
}
func TestArrayExpressionAT(t *testing.T) {
assertClauseSerialize(t, table1ColStringArray.AT(Int(1)), "(table1.col_array_string[$1])", int64(1))
}

View file

@ -121,6 +121,46 @@ func IntegerColumn(name string) ColumnInteger {
//------------------------------------------------------//
type ColumnArray[E Expression] interface {
ArrayExpression[E]
Column
From(subQuery SelectTable) ColumnArray[E]
SET(stringExp ArrayExpression[E]) ColumnAssigment
}
type arrayColumnImpl[E Expression] struct {
arrayInterfaceImpl[E]
ColumnExpressionImpl
}
func (a arrayColumnImpl[E]) From(subQuery SelectTable) ColumnArray[E] {
newArrayColumn := ArrayColumn[E](a.name)
newArrayColumn.setTableName(a.tableName)
newArrayColumn.setSubQuery(subQuery)
return newArrayColumn
}
func (a *arrayColumnImpl[E]) SET(stringExp ArrayExpression[E]) ColumnAssigment {
return columnAssigmentImpl{
column: a,
expression: stringExp,
}
}
// StringColumn creates named string column.
func ArrayColumn[E Expression](name string) ColumnArray[E] {
arrayColumn := &arrayColumnImpl[E]{}
arrayColumn.arrayInterfaceImpl.parent = arrayColumn
arrayColumn.ColumnExpressionImpl = NewColumnImpl(name, "", arrayColumn)
return arrayColumn
}
//------------------------------------------------------//
// ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types.
type ColumnString interface {

View file

@ -1,6 +1,7 @@
package jet
import (
"github.com/lib/pq"
"testing"
)
@ -8,6 +9,42 @@ var subQuery = &selectTableImpl{
alias: "sub_query",
}
func TestNewArrayColumnString(t *testing.T) {
stringArrayColumn := ArrayColumn[StringExpression]("colArray").From(subQuery)
assertClauseSerialize(t, stringArrayColumn, `sub_query."colArray"`)
assertClauseSerialize(t, stringArrayColumn.EQ(StringArray([]string{"X"})), `(sub_query."colArray" = $1)`, pq.StringArray{"X"})
assertProjectionSerialize(t, stringArrayColumn, `sub_query."colArray" AS "colArray"`)
arrayColumn2 := table1ColStringArray.From(subQuery)
assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_string"`)
assertClauseSerialize(t, arrayColumn2.EQ(StringArray([]string{"X"})), `(sub_query."table1.col_array_string" = $1)`, pq.StringArray{"X"})
assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_string" AS "table1.col_array_string"`)
}
func TestNewArrayColumnBool(t *testing.T) {
boolArrayColumn := ArrayColumn[BoolExpression]("colArrayBool").From(subQuery)
assertClauseSerialize(t, boolArrayColumn, `sub_query."colArrayBool"`)
assertClauseSerialize(t, boolArrayColumn.EQ(BoolArray([]bool{true})), `(sub_query."colArrayBool" = $1)`, pq.BoolArray{true})
assertProjectionSerialize(t, boolArrayColumn, `sub_query."colArrayBool" AS "colArrayBool"`)
arrayColumn2 := table1ColBoolArray.From(subQuery)
assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_bool"`)
assertClauseSerialize(t, arrayColumn2.EQ(BoolArray([]bool{true})), `(sub_query."table1.col_array_bool" = $1)`, pq.BoolArray{true})
assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_bool" AS "table1.col_array_bool"`)
}
func TestNewArrayColumnInteger(t *testing.T) {
intArrayColumn := ArrayColumn[IntegerExpression]("colArrayInt").From(subQuery)
assertClauseSerialize(t, intArrayColumn, `sub_query."colArrayInt"`)
assertClauseSerialize(t, intArrayColumn.EQ(Int32Array([]int32{42})), `(sub_query."colArrayInt" = $1)`, pq.Int32Array{42})
assertProjectionSerialize(t, intArrayColumn, `sub_query."colArrayInt" AS "colArrayInt"`)
arrayColumn2 := table1ColIntArray.From(subQuery)
assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_int"`)
assertClauseSerialize(t, arrayColumn2.EQ(Int32Array([]int32{42})), `(sub_query."table1.col_array_int" = $1)`, pq.Int32Array{42})
assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_int" AS "table1.col_array_int"`)
}
func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery)
assertClauseSerialize(t, boolColumn, `sub_query."colBool"`)

View file

@ -316,6 +316,32 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder,
}
}
type arraySubscriptExpression struct {
ExpressionInterfaceImpl
array Expression
subscript IntegerExpression
}
func (a arraySubscriptExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
}
a.array.serialize(statement, out, FallTrough(options)...) // FallTrough here because complexExpression is just a wrapper
out.WriteString("[")
a.subscript.serialize(statement, out, FallTrough(options)...) // FallTrough here because complexExpression is just a wrapper
out.WriteString("]")
if !contains(options, NoWrap) {
out.WriteString(")")
}
}
func arraySubscriptExpr(array Expression, subscript IntegerExpression) Expression {
arraySubscriptExpression := &arraySubscriptExpression{array: array, subscript: subscript}
arraySubscriptExpression.ExpressionInterfaceImpl.Parent = arraySubscriptExpression
return arraySubscriptExpression
}
type skipParenthesisWrap struct {
Expression
}

View file

@ -2,6 +2,7 @@ package jet
import (
"fmt"
"github.com/lib/pq"
"time"
)
@ -160,6 +161,66 @@ func Decimal(value string) FloatExpression {
return &floatLiteral
}
// ---------------------------------------------------//
type boolArrayLiteral struct {
arrayInterfaceImpl[BoolExpression]
literalExpressionImpl
}
func BoolArray(values []bool) ArrayExpression[BoolExpression] {
l := boolArrayLiteral{}
l.literalExpressionImpl = *literal(pq.BoolArray(values))
l.arrayInterfaceImpl.parent = &l
return &l
}
type integerArrayLiteral struct {
arrayInterfaceImpl[IntegerExpression]
literalExpressionImpl
}
func Int64Array(values []int64) ArrayExpression[IntegerExpression] {
l := integerArrayLiteral{}
l.literalExpressionImpl = *literal(pq.Int64Array(values))
l.arrayInterfaceImpl.parent = &l
return &l
}
func Int32Array(values []int32) ArrayExpression[IntegerExpression] {
l := integerArrayLiteral{}
l.literalExpressionImpl = *literal(pq.Int32Array(values))
l.arrayInterfaceImpl.parent = &l
return &l
}
type stringArrayLiteral struct {
arrayInterfaceImpl[StringExpression]
literalExpressionImpl
}
func StringArray(values []string) ArrayExpression[StringExpression] {
l := stringArrayLiteral{}
l.literalExpressionImpl = *literal(pq.StringArray(values))
l.arrayInterfaceImpl.parent = &l
return &l
}
type unsafeArrayLiteral[E Expression] struct {
arrayInterfaceImpl[E]
literalExpressionImpl
}
func UnsafeArray[E LiteralExpression](values []interface{}) ArrayExpression[E] {
l := unsafeArrayLiteral[E]{}
l.literalExpressionImpl = *literal(pq.Array(values))
l.arrayInterfaceImpl.parent = &l
return &l
}
// ---------------------------------------------------//
type stringLiteral struct {
stringInterfaceImpl

View file

@ -22,6 +22,15 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression {
return newPrefixIntegerOperatorExpression(expr, "~")
}
// ----------- Array operators -------------- //
func Any(lhs Expression, op BinaryBoolOp, rhs Expression) BoolExpression {
return op(lhs, Func("ANY", rhs))
}
func All(lhs Expression, op BinaryBoolOp, rhs Expression) BoolExpression {
return op(lhs, Func("ALL", rhs))
}
//----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
@ -74,6 +83,11 @@ func Contains(lhs Expression, rhs Expression) BoolExpression {
return newBinaryBoolOperatorExpression(lhs, rhs, "@>")
}
// IsContainedBy returns a representation of "a <@ b"
func IsContainedBy(lhs Expression, rhs Expression) BoolExpression {
return newBinaryBoolOperatorExpression(lhs, rhs, "<@")
}
// Overlap returns a representation of "a && b"
func Overlap(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperatorExpression(lhs, rhs, "&&")

View file

@ -81,11 +81,11 @@ func (s *SQLBuilder) write(data []byte) {
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' || b == '['
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' || b == '[' || b == ']'
}
// WriteAlias is used to add alias to output SQL
@ -226,6 +226,8 @@ func argToString(value interface{}) string {
case string:
return stringQuote(bindVal)
case []string:
return stringArrayQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
@ -253,6 +255,19 @@ func argToString(value interface{}) string {
}
}
func stringArrayQuote(val []string) string {
var sb strings.Builder
sb.WriteString(`'{`)
for i := 0; i < len(val); i++ {
if i > 0 {
sb.WriteString(`, `)
}
sb.WriteString(stringDoubleQuote(val[i]))
}
sb.WriteString(`}'`)
return sb.String()
}
func integerTypesToString(value interface{}) string {
switch bindVal := value.(type) {
case int:
@ -301,3 +316,7 @@ func shouldQuoteIdentifier(identifier string) bool {
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}
func stringDoubleQuote(value string) string {
return `"` + strings.Replace(value, `"`, `""`, -1) + `"`
}

View file

@ -16,6 +16,9 @@ type StringExpression interface {
BETWEEN(min, max StringExpression) BoolExpression
NOT_BETWEEN(min, max StringExpression) BoolExpression
ANY_EQ(rhs ArrayExpression[StringExpression]) BoolExpression
ALL_EQ(rhs ArrayExpression[StringExpression]) BoolExpression
CONCAT(rhs Expression) StringExpression
LIKE(pattern StringExpression) BoolExpression
@ -69,6 +72,14 @@ func (s *stringInterfaceImpl) NOT_BETWEEN(min, max StringExpression) BoolExpress
return NewBetweenOperatorExpression(s.parent, min, max, true)
}
func (i *stringInterfaceImpl) ANY_EQ(rhs ArrayExpression[StringExpression]) BoolExpression {
return Any(i.parent, Eq, rhs)
}
func (i *stringInterfaceImpl) ALL_EQ(rhs ArrayExpression[StringExpression]) BoolExpression {
return All(i.parent, Eq, rhs)
}
func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression {
return newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator)
}

View file

@ -76,6 +76,14 @@ func TestStringNOT_REGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP $1)", "JOHN")
}
func TestStringANY_EQ(t *testing.T) {
assertClauseSerialize(t, table2ColStr.ANY_EQ(table1ColStringArray), "(table2.col_str = ANY(table1.col_array_string))")
}
func TestStringALL_EQ(t *testing.T) {
assertClauseSerialize(t, table2ColStr.ALL_EQ(table1ColStringArray), "(table2.col_str = ALL(table1.col_array_string))")
}
func TestStringExp(t *testing.T) {
assertClauseSerialize(t, StringExp(table2ColFloat), "table2.col_float")
assertClauseSerialize(t, StringExp(table2ColFloat).NOT_LIKE(String("abc")), "(table2.col_float NOT LIKE $1)", "abc")

View file

@ -15,19 +15,22 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests
})
var (
table1Col1 = IntegerColumn("col1")
table1ColInt = IntegerColumn("col_int")
table1ColFloat = FloatColumn("col_float")
table1Col3 = IntegerColumn("col3")
table1ColTime = TimeColumn("col_time")
table1ColTimez = TimezColumn("col_timez")
table1ColTimestamp = TimestampColumn("col_timestamp")
table1ColTimestampz = TimestampzColumn("col_timestampz")
table1ColBool = BoolColumn("col_bool")
table1ColDate = DateColumn("col_date")
table1ColRange = RangeColumn[Int8Expression]("col_range")
table1Col1 = IntegerColumn("col1")
table1ColInt = IntegerColumn("col_int")
table1ColFloat = FloatColumn("col_float")
table1Col3 = IntegerColumn("col3")
table1ColTime = TimeColumn("col_time")
table1ColTimez = TimezColumn("col_timez")
table1ColTimestamp = TimestampColumn("col_timestamp")
table1ColTimestampz = TimestampzColumn("col_timestampz")
table1ColBool = BoolColumn("col_bool")
table1ColDate = DateColumn("col_date")
table1ColRange = RangeColumn[Int8Expression]("col_range")
table1ColStringArray = ArrayColumn[StringExpression]("col_array_string")
table1ColBoolArray = ArrayColumn[BoolExpression]("col_array_bool")
table1ColIntArray = ArrayColumn[IntegerExpression]("col_array_int")
)
var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz)
var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz, table1ColStringArray, table1ColBoolArray, table1ColIntArray)
var (
table2Col3 = IntegerColumn("col3")
@ -42,8 +45,9 @@ var (
table2ColTimestampz = TimestampzColumn("col_timestampz")
table2ColDate = DateColumn("col_date")
table2ColRange = RangeColumn[Int8Expression]("col_range")
table2ColArray = ArrayColumn[StringExpression]("col_array_string")
)
var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz)
var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz, table2ColArray)
var (
table3Col1 = IntegerColumn("col1")