Merge remote-tracking branch 'upstream/master' into stmt-cache2

# Conflicts:
#	tests/postgres/alltypes_test.go
#	tests/postgres/northwind_test.go
#	tests/postgres/sample_test.go
#	tests/postgres/update_test.go
#	tests/sqlite/insert_test.go
#	tests/sqlite/main_test.go
#	tests/sqlite/sample_test.go
#	tests/sqlite/update_test.go
This commit is contained in:
go-jet 2024-10-19 14:01:55 +02:00
commit 4bb9775134
97 changed files with 2306 additions and 537 deletions

View file

@ -11,6 +11,13 @@ type columnAssigmentImpl struct {
expression Expression
}
func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment {
return &columnAssigmentImpl{
column: serializer,
expression: expression,
}
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {

View file

@ -13,6 +13,7 @@ type Dialect interface {
ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string
}
// SerializerFunc func
@ -35,6 +36,7 @@ type DialectParams struct {
ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string
}
// NewDialect creates new dialect with params
@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect {
argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName,
}
}
@ -62,6 +65,7 @@ type dialectImpl struct {
argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string
}
func (d *dialectImpl) Name() string {
@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending,
return d.serializeOrderBy
}
func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
return d.valuesDefaultColumnName(index)
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {

View file

@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
// IN checks if this expressions matches any in expressions list
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN")
}
// NOT_IN checks if this expressions is different of all expressions in expressions list
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN")
}
// AS the temporary alias name to assign to the expression
@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder,
}
}
type skipParenthesisWrap struct {
Expression
}
func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}
// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
func wrap(expressions ...Expression) Expression {
return NewFunc("", expressions, nil)
}

View file

@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("OR", expressions...)
}
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(expressions ...Expression) Expression {
return NewFunc("ROW", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder
if _, isStatement := expression.(Statement); isStatement {
expression.serialize(statement, out, options...)
} else {
skipWrap(expression).serialize(statement, out, options...)
expression.serialize(statement, out, append(options, NoWrap, Ident)...)
}
}
}

View file

@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option
//---------------------------------------------------//
type wrap struct {
ExpressionInterfaceImpl
expressions []Expression
}
func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
if len(n.expressions) == 1 {
options = append(options, NoWrap, Ident)
}
serializeExpressionList(statementType, n.expressions, ", ", out, options...)
out.WriteString(")")
}
// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.ExpressionInterfaceImpl.Parent = wrap
return wrap
}
//---------------------------------------------------//
type rawExpression struct {
ExpressionInterfaceImpl

View file

@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct {
func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name)
WRAP(p.fraction).serialize(statement, out, FallTrough(options)...)
if p.fraction != nil {
wrap(p.fraction).serialize(statement, out, FallTrough(options)...)
} else {
wrap().serialize(statement, out, FallTrough(options)...)
}
out.WriteString("WITHIN GROUP")
p.orderBy.serialize(statement, out)
}

View file

@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg",
table2.col3 AS "col3",
table2.col4 AS "col4"`)
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery"))
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil))
assertProjectionSerialize(t, subQueryProjections,
`"subQuery"."table1.col3" AS "table1.col3",

View file

@ -8,7 +8,7 @@ type rawStatementImpl struct {
}
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement {
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,

View file

@ -0,0 +1,102 @@
package jet
// RowExpression interface
type RowExpression interface {
Expression
HasProjections
EQ(rhs RowExpression) BoolExpression
NOT_EQ(rhs RowExpression) BoolExpression
IS_DISTINCT_FROM(rhs RowExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression
LT(rhs RowExpression) BoolExpression
LT_EQ(rhs RowExpression) BoolExpression
GT(rhs RowExpression) BoolExpression
GT_EQ(rhs RowExpression) BoolExpression
}
type rowInterfaceImpl struct {
parent Expression
dialect Dialect
elemCount int
}
func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
return Eq(n.parent, rhs)
}
func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression {
return NotEq(n.parent, rhs)
}
func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsDistinctFrom(n.parent, rhs)
}
func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsNotDistinctFrom(n.parent, rhs)
}
func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression {
return Gt(n.parent, rhs)
}
func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression {
return GtEq(n.parent, rhs)
}
func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression {
return Lt(n.parent, rhs)
}
func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
return LtEq(n.parent, rhs)
}
func (n *rowInterfaceImpl) projections() ProjectionList {
var ret ProjectionList
for i := 0; i < n.elemCount; i++ {
rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil)
ret = append(ret, &rowColumn)
}
return ret
}
// ---------------------------------------------------//
type rowExpressionWrapper struct {
rowInterfaceImpl
Expression
}
func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression {
ret := &rowExpressionWrapper{}
ret.rowInterfaceImpl.parent = ret
ret.Expression = NewFunc(name, expressions, ret)
ret.dialect = dialect
ret.elemCount = len(expressions)
return ret
}
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("ROW", dialect, expressions...)
}
// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`.
func WRAP(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("", dialect, expressions...)
}
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
func RowExp(expression Expression) RowExpression {
rowExpressionWrap := rowExpressionWrapper{Expression: expression}
rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap
return &rowExpressionWrap
}

View file

@ -8,15 +8,21 @@ type SelectTable interface {
}
type selectTableImpl struct {
Statement SerializerHasProjections
alias string
Statement SerializerHasProjections
alias string
columnAliases []ColumnExpression
}
// NewSelectTable func
func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl {
selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
Statement: selectStmt,
alias: alias,
columnAliases: columnAliases,
}
for _, column := range selectTable.columnAliases {
column.setSubQuery(selectTable)
}
return selectTable
@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string {
}
func (s selectTableImpl) AllColumns() ProjectionList {
if len(s.columnAliases) > 0 {
return ColumnListToProjectionList(s.columnAliases)
}
projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt
out.WriteString("AS")
out.WriteIdentifier(s.alias)
if len(s.columnAliases) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(s.columnAliases, out)
out.WriteByte(')')
}
}
// --------------------------------------
@ -50,7 +66,7 @@ type lateralImpl struct {
// NewLateral creates new lateral expression from select statement with alias
func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)}
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)}
}
func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {

View file

@ -180,7 +180,7 @@ func duration(f func()) time.Duration {
f()
return time.Now().Sub(start)
return time.Since(start)
}
// ExpressionStatement interfacess

35
internal/jet/values.go Normal file
View file

@ -0,0 +1,35 @@
package jet
// Values hold a set of one or more rows
type Values []RowExpression
func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.IncreaseIdent(5)
out.NewLine()
out.WriteString("VALUES")
for rowIndex, row := range v {
if rowIndex > 0 {
out.WriteString(",")
out.NewLine()
} else {
out.IncreaseIdent(7)
}
row.serialize(statement, out, options...)
}
out.DecreaseIdent(7)
out.DecreaseIdent(5)
out.NewLine()
out.WriteByte(')')
}
func (v Values) projections() ProjectionList {
if len(v) == 0 {
return nil
}
return v[0].projections()
}

View file

@ -64,7 +64,7 @@ type CommonTableExpression struct {
// CTE creates new named CommonTableExpression
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
selectTableImpl: NewSelectTable(nil, name),
selectTableImpl: NewSelectTable(nil, name, columns),
Columns: columns,
}
@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde
out.WriteIdentifier(c.alias)
}
}
// AllColumns returns list of all projections in the CTE
func (c CommonTableExpression) AllColumns() ProjectionList {
if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
}
return c.selectTableImpl.AllColumns()
}

View file

@ -9,17 +9,15 @@ import (
jet2 "github.com/go-jet/jet/v2/internal/jet/db"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
// UnixTimeComparer will compare time equality while ignoring time zone
@ -105,11 +103,12 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
}
// SaveJSONFile saves v as json at testRelativePath
// nolint:unused
func SaveJSONFile(v interface{}, testRelativePath string) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644)
err := os.WriteFile(filePath, jsonText, 0600)
throw.OnError(err)
}
@ -118,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath)
fileJSONData, err := os.ReadFile(filePath) // #nosec G304
require.NoError(t, err)
if runtime.GOOS == "windows" {
@ -245,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter
// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
enumFileData, err := os.ReadFile(filePath) // #nosec G304
require.NoError(t, err)
@ -254,7 +253,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) {
files, err := ioutil.ReadDir(dirPath)
files, err := os.ReadDir(dirPath)
require.NoError(t, err)
require.Equal(t, len(files), len(fileNames))
@ -293,76 +292,6 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) {
fmt.Println(expected)
}
// BoolPtr returns address of bool parameter
func BoolPtr(b bool) *bool {
return &b
}
// Int8Ptr returns address of int8 parameter
func Int8Ptr(i int8) *int8 {
return &i
}
// UInt8Ptr returns address of uint8 parameter
func UInt8Ptr(i uint8) *uint8 {
return &i
}
// Int16Ptr returns address of int16 parameter
func Int16Ptr(i int16) *int16 {
return &i
}
// UInt16Ptr returns address of uint16 parameter
func UInt16Ptr(i uint16) *uint16 {
return &i
}
// Int32Ptr returns address of int32 parameter
func Int32Ptr(i int32) *int32 {
return &i
}
// UInt32Ptr returns address of uint32 parameter
func UInt32Ptr(i uint32) *uint32 {
return &i
}
// Int64Ptr returns address of int64 parameter
func Int64Ptr(i int64) *int64 {
return &i
}
// UInt64Ptr returns address of uint64 parameter
func UInt64Ptr(i uint64) *uint64 {
return &i
}
// StringPtr returns address of string parameter
func StringPtr(s string) *string {
return &s
}
// TimePtr returns address of time.Time parameter
func TimePtr(t time.Time) *time.Time {
return &t
}
// ByteArrayPtr returns address of []byte parameter
func ByteArrayPtr(arr []byte) *[]byte {
return &arr
}
// Float32Ptr returns address of float32 parameter
func Float32Ptr(f float32) *float32 {
return &f
}
// Float64Ptr returns address of float64 parameter
func Float64Ptr(f float64) *float64 {
return &f
}
// UUIDPtr returns address of uuid.UUID
func UUIDPtr(u string) *uuid.UUID {
newUUID := uuid.MustParse(u)

View file

@ -1,6 +1,7 @@
package filesys
import (
"errors"
"fmt"
"go/format"
"os"
@ -16,7 +17,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath += ".go"
}
file, err := os.Create(newGoFilePath)
file, err := os.Create(newGoFilePath) // #nosec 304
if err != nil {
return err
@ -28,7 +29,10 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
// if there is a format error we will write unformulated text for debug purposes
if err != nil {
file.Write(text)
_, writeErr := file.Write(text)
if writeErr != nil {
return errors.Join(writeErr, fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err))
}
return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)
}
@ -43,7 +47,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
// EnsureDirPathExist ensures dir path exists. If path does not exist, creates new path.
func EnsureDirPathExist(dirPath string) error {
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err := os.MkdirAll(dirPath, os.ModePerm)
err := os.MkdirAll(dirPath, 0o750)
if err != nil {
return fmt.Errorf("can't create directory - %s: %w", dirPath, err)

View file

@ -0,0 +1,6 @@
package ptr
// Of returns the address of any given parameter
func Of[T any](value T) *T {
return &value
}