Jet internal refactor.

This commit is contained in:
go-jet 2019-08-11 14:29:03 +02:00
parent 4fbf576370
commit ee4897a1e2
49 changed files with 481 additions and 2528 deletions

View file

@ -71,19 +71,6 @@ func TestBoolLiteral(t *testing.T) {
assertClauseSerialize(t, Bool(false), "$1", false)
}
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`(EXISTS (
SELECT $1
FROM db.table2
WHERE table1.col1 = table2.col3
))`, int64(1))
}
func TestBoolExp(t *testing.T) {
assertClauseSerialize(t, BoolExp(String("true")), "$1", "true")
assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true")

View file

@ -18,12 +18,12 @@ type CastImpl struct {
expression Expression
}
func NewCastImpl(expression Expression) CastImpl {
func NewCastImpl(expression Expression) Cast {
castImpl := CastImpl{
expression: expression,
}
return castImpl
return &castImpl
}
func (b *CastImpl) AS(castType string) Expression {

View file

@ -1,7 +1,11 @@
package jet
//func TestCastAS(t *testing.T) {
// AssertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST(? AS boolean)", int64(1))
// AssertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
// AssertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
//}
import (
"testing"
)
func TestCastAS(t *testing.T) {
assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1))
assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
}

View file

@ -36,7 +36,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) e
return errors.New("jet: no column selected for Projection")
}
return out.writeProjections(statementType, s.Projections)
return out.WriteProjections(statementType, s.Projections)
}
type ClauseFrom struct {
@ -77,9 +77,9 @@ func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder)
out.NewLine()
out.WriteString("GROUP BY")
out.increaseIdent()
out.IncreaseIdent()
err := serializeGroupByClauseList(statementType, c.List, out)
out.decreaseIdent()
out.DecreaseIdent()
return err
}
@ -173,15 +173,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB
wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0
//if wrap {
// out.WriteString("(")
// out.increaseIdent()
//}
if wrap {
out.NewLine()
out.WriteString("(")
out.increaseIdent()
out.IncreaseIdent()
}
for i, selectStmt := range s.Selects {
@ -207,7 +202,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB
}
if wrap {
out.decreaseIdent()
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")
}
@ -224,12 +219,6 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB
return err
}
//if wrap {
// out.decreaseIdent()
// out.newLine()
// out.WriteString(")")
//}
return nil
}
@ -253,7 +242,7 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) e
}
type ClauseSet struct {
Columns []IColumn
Columns []Column
Values []Serializer
}
@ -265,7 +254,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro
return errors.New("jet: mismatch in numers of columns and values")
}
out.increaseIdent(4)
out.IncreaseIdent(4)
for i, column := range s.Columns {
if i > 0 {
out.WriteString(", ")
@ -280,26 +269,26 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro
out.WriteString(" = ")
if err := Serialize(s.Values[i], UpdateStatementType, out); err != nil {
if err := s.Values[i].serialize(UpdateStatementType, out); err != nil {
return err
}
}
out.decreaseIdent(4)
out.DecreaseIdent(4)
return nil
}
type ClauseReturning struct {
Projections []Projection
}
func (r *ClauseReturning) Serialize(statementType StatementType, out *SqlBuilder) error {
return out.WriteReturning(statementType, r.Projections)
}
type ClauseInsert struct {
Table SerializerTable
Columns []IColumn
Columns []Column
}
func (i *ClauseInsert) GetColumns() []Column {
if len(i.Columns) > 0 {
return i.Columns
}
return i.Table.Columns()
}
func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error {
@ -347,7 +336,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e
out.WriteString(",")
}
out.increaseIdent()
out.IncreaseIdent()
out.NewLine()
out.WriteString("(")
@ -358,7 +347,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e
}
out.writeByte(')')
out.decreaseIdent()
out.DecreaseIdent()
}
return nil
}
@ -459,235 +448,3 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) error
return nil
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable2(Dialect Dialect, schemaName, name string, columns ...Column) TableImpl2 {
t := TableImpl2{
Dialect: Dialect,
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.SetTableName(name)
}
return t
}
type TableImpl2 struct {
Dialect Dialect
schemaName string
name string
alias string
columnList []Column
}
func (t *TableImpl2) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
c.SetTableName(alias)
}
}
func (t *TableImpl2) SchemaName() string {
return t.schemaName
}
func (t *TableImpl2) TableName() string {
return t.name
}
func (t *TableImpl2) Columns() []IColumn {
ret := []IColumn{}
for _, col := range t.columnList {
ret = append(ret, col)
}
return ret
}
func (t *TableImpl2) dialect() Dialect {
return t.Dialect
}
func (t *TableImpl2) accept(visitor visitor) {
visitor.visit(t)
}
func (t *TableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if t == nil {
return errors.New("jet: tableImpl is nil. ")
}
out.writeIdentifier(t.schemaName)
out.WriteString(".")
out.writeIdentifier(t.name)
if len(t.alias) > 0 {
out.WriteString("AS")
out.writeIdentifier(t.alias)
}
return nil
}
// Join expressions are pseudo readable tables.
type JoinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
joinTable := JoinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
return joinTable
}
func (t *JoinTableImpl) SchemaName() string {
return ""
}
func (t *JoinTableImpl) TableName() string {
return ""
}
func (t *JoinTableImpl) Columns() []IColumn {
//return append(t.lhs.columns(), t.rhs.columns()...)
panic("Unimplemented")
}
func (t *JoinTableImpl) accept(visitor visitor) {
//t.lhs.accept(visitor)
//t.rhs.accept(visitor)
//TODO: uncoment
}
func (t *JoinTableImpl) dialect() Dialect {
return detectDialect(t)
}
func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
if t == nil {
return errors.New("jet: Join table is nil. ")
}
if utils.IsNil(t.lhs) {
return errors.New("jet: left hand side of join operation is nil table")
}
if err = t.lhs.serialize(statement, out); err != nil {
return
}
out.NewLine()
switch t.joinType {
case InnerJoin:
out.WriteString("INNER JOIN")
case LeftJoin:
out.WriteString("LEFT JOIN")
case RightJoin:
out.WriteString("RIGHT JOIN")
case FullJoin:
out.WriteString("FULL JOIN")
case CrossJoin:
out.WriteString("CROSS JOIN")
}
if utils.IsNil(t.rhs) {
return errors.New("jet: right hand side of join operation is nil table")
}
if err = t.rhs.serialize(statement, out); err != nil {
return
}
if t.onCondition == nil && t.joinType != CrossJoin {
return errors.New("jet: join condition is nil")
}
if t.onCondition != nil {
out.WriteString("ON")
if err = t.onCondition.serialize(statement, out); err != nil {
return
}
}
return nil
}
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Alias() string
AllColumns() ProjectionList
}
type SelectTableImpl2 struct {
selectStmt StatementWithProjections
alias string
projections []Projection
}
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 {
selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias}
for _, projection := range selectStmt.projections() {
newProjection := projection.fromImpl(&selectTable)
selectTable.projections = append(selectTable.projections, newProjection)
}
return selectTable
}
func (s *SelectTableImpl2) Alias() string {
return s.alias
}
func (s *SelectTableImpl2) Columns() []IColumn {
return nil
}
func (s *SelectTableImpl2) accept(visitor visitor) {
visitor.visit(s)
s.selectStmt.accept(visitor)
}
func (s *SelectTableImpl2) dialect() Dialect {
return detectDialect(s.selectStmt)
}
func (s *SelectTableImpl2) AllColumns() ProjectionList {
return s.projections
}
func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Expression table is nil. ")
}
err := s.selectStmt.serialize(statement, out)
if err != nil {
return err
}
out.WriteString("AS")
out.writeIdentifier(s.alias)
return nil
}

View file

@ -2,7 +2,7 @@
package jet
type IColumn interface {
type Column interface {
Name() string
TableName() string
@ -12,9 +12,9 @@ type IColumn interface {
}
// Column is common column interface for all types of columns.
type Column interface {
type ColumnExpression interface {
Column
Expression
IColumn
}
// The base type for real materialized columns.
@ -28,7 +28,7 @@ type columnImpl struct {
subQuery SelectTable
}
func newColumn(name string, tableName string, parent Column) columnImpl {
func newColumn(name string, tableName string, parent ColumnExpression) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
@ -109,19 +109,19 @@ func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options
type IColumnList interface {
Projection
IColumn
Column
Columns() []Column
Columns() []ColumnExpression
}
func ColumnList(columns ...Column) IColumnList {
func ColumnList(columns ...ColumnExpression) IColumnList {
return columnListImpl(columns)
}
// ColumnList is redefined type to support list of columns as single Projection
type columnListImpl []Column
type columnListImpl []ColumnExpression
func (cl columnListImpl) Columns() []Column {
func (cl columnListImpl) Columns() []ColumnExpression {
return cl
}

View file

@ -3,7 +3,7 @@ package jet
// ColumnBool is interface for SQL boolean columns.
type ColumnBool interface {
BoolExpression
IColumn
Column
From(subQuery SelectTable) ColumnBool
}
@ -42,7 +42,7 @@ func BoolColumn(name string) ColumnBool {
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface {
FloatExpression
IColumn
Column
From(subQuery SelectTable) ColumnFloat
}
@ -80,7 +80,7 @@ func FloatColumn(name string) ColumnFloat {
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface {
IntegerExpression
IColumn
Column
From(subQuery SelectTable) ColumnInteger
}
@ -118,7 +118,7 @@ func IntegerColumn(name string) ColumnInteger {
// bytea, uuid columns and enums types.
type ColumnString interface {
StringExpression
IColumn
Column
From(subQuery SelectTable) ColumnString
}
@ -155,7 +155,7 @@ func StringColumn(name string) ColumnString {
// ColumnTime is interface for SQL time column.
type ColumnTime interface {
TimeExpression
IColumn
Column
From(subQuery SelectTable) ColumnTime
}
@ -190,7 +190,7 @@ func TimeColumn(name string) ColumnTime {
// ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface {
TimezExpression
IColumn
Column
From(subQuery SelectTable) ColumnTimez
}
@ -227,7 +227,7 @@ func TimezColumn(name string) ColumnTimez {
// ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface {
TimestampExpression
IColumn
Column
From(subQuery SelectTable) ColumnTimestamp
}
@ -264,7 +264,7 @@ func TimestampColumn(name string) ColumnTimestamp {
// ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface {
TimestampzExpression
IColumn
Column
From(subQuery SelectTable) ColumnTimestampz
}
@ -301,7 +301,7 @@ func TimestampzColumn(name string) ColumnTimestampz {
// ColumnDate is interface of SQL date columns.
type ColumnDate interface {
DateExpression
IColumn
Column
From(subQuery SelectTable) ColumnDate
}

View file

@ -4,7 +4,9 @@ import (
"testing"
)
var subQuery = table1.SELECT(table1ColFloat, table1ColInt).AsTable("sub_query")
var subQuery = &SelectTableImpl2{
alias: "sub_query",
}
func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery)

View file

@ -1,110 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
RETURNING(projections ...Projection) DeleteStatement
}
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
returning []Projection
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...Projection) DeleteStatement {
d.returning = projections
return d
}
func (d *deleteStatementImpl) accept(visitor visitor) {
visitor.visit(d)
d.table.accept(visitor)
}
func (d *deleteStatementImpl) serializeImpl(out *SqlBuilder) error {
if d == nil {
return errors.New("jet: delete statement is nil")
}
out.NewLine()
out.WriteString("DELETE FROM")
if d.table == nil {
return errors.New("jet: nil tableName")
}
if err := d.table.serialize(DeleteStatementType, out); err != nil {
return err
}
if d.where == nil {
return errors.New("jet: deleting without a WHERE clause")
}
if err := out.writeWhere(DeleteStatementType, d.where); err != nil {
return err
}
if err := out.WriteReturning(DeleteStatementType, d.returning); err != nil {
return err
}
return nil
}
func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &SqlBuilder{
Dialect: detectDialect(d, dialect...),
}
err = d.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}
func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(d, dialect...)
}
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, d, db, destination)
}
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, d, db)
}

View file

@ -1,25 +0,0 @@
package jet
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementErr(t, table1.DELETE(), `jet: deleting without a WHERE clause`)
assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: deleting without a WHERE clause`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = $1;
`, int64(1))
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), `
DELETE FROM db.table1
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`, int64(1))
}

View file

@ -1,17 +1,5 @@
package jet
import (
"strconv"
)
var ANSII = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
type Dialect interface {
Name() string
PackageName() string
@ -25,7 +13,7 @@ type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...Ser
type SerializeOverride func(expressions ...Expression) SerializeFunc
type QueryPlaceholderFunc func(ord int) string
type UpdateAssigmentFunc func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error)
type UpdateAssigmentFunc func(columns []Column, values []Serializer, out *SqlBuilder) (err error)
type DialectParams struct {
Name string

View file

@ -26,33 +26,14 @@ func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
}
func TestIN(t *testing.T) {
assertClauseSerialize(t, table2ColInt.IN(Int(1), Int(2), Int(3)),
`(table2.col_int IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
}
func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
`($1 NOT IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, table2ColInt.NOT_IN(Int(1), Int(2), Int(3)),
`(table2.col_int NOT IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
}

View file

@ -1,180 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
// Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...Projection) InsertStatement
}
func newInsertStatement(t WritableTable, columns []IColumn) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
}
}
type insertStatementImpl struct {
table WritableTable
columns []IColumn
rows [][]Serializer
query SelectStatement
returning []Projection
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.rows = append(i.rows, UnwindRowFromValues(value, values))
return i
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.rows = append(i.rows, UnwindRowFromModel(i.getColumns(), data))
return i
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.rows = append(i.rows, UnwindRowsFromModels(i.getColumns(), data)...)
return i
}
func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement {
i.returning = projections
return i
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement
return i
}
func (i *insertStatementImpl) getColumns() []IColumn {
if len(i.columns) > 0 {
return i.columns
}
return i.table.columns()
}
func (i *insertStatementImpl) accept(visitor visitor) {
visitor.visit(i)
i.table.accept(visitor)
}
func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(i, dialect...)
}
func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
out := &SqlBuilder{
Dialect: detectDialect(i, dialect...),
}
out.NewLine()
out.WriteString("INSERT INTO")
if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil")
}
err = i.table.serialize(InsertStatementType, out)
if err != nil {
return
}
if len(i.columns) > 0 {
out.WriteString("(")
err = SerializeColumnNames(i.columns, out)
if err != nil {
return
}
out.WriteString(")")
}
//TODO:
if len(i.rows) == 0 && i.query == nil {
return "", nil, errors.New("jet: no row values or query specified")
}
if len(i.rows) > 0 && i.query != nil {
return "", nil, errors.New("jet: only row values or query has to be specified")
}
if len(i.rows) > 0 {
out.WriteString("VALUES")
for rowIndex, row := range i.rows {
if rowIndex > 0 {
out.WriteString(",")
}
out.increaseIdent()
out.NewLine()
out.WriteString("(")
err = SerializeClauseList(InsertStatementType, row, out)
if err != nil {
return "", nil, err
}
out.writeByte(')')
out.decreaseIdent()
}
}
if i.query != nil {
err = i.query.serialize(InsertStatementType, out)
if err != nil {
return
}
}
if err = out.WriteReturning(InsertStatementType, i.returning); err != nil {
return
}
query, args = out.finalize()
return
}
func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db)
}
func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, i, db)
}

View file

@ -1,147 +0,0 @@
package jet
import (
"gotest.tools/assert"
"testing"
"time"
)
func TestInvalidInsert(t *testing.T) {
assertStatementErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified")
assertStatementErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list")
}
func TestInsertNilValue(t *testing.T) {
assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES
($1);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatement(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES
($1);
`, int(1))
}
func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList(table3ColInt, table3StrCol)
assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES
($1, $2);
`, 1, 3)
}
func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatement(t, table1.INSERT(table1ColTime).VALUES(date), `
INSERT INTO db.table1 (col_time) VALUES
($1);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES
($1, $2, $3);
`, 1, 2, 3)
}
func TestInsertMultipleRows(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(1, 2).
VALUES(11, 22).
VALUES(111, 222)
assertStatement(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4),
($5, $6);
`, 1, 2, 11, 22, 111, 222)
}
func TestInsertValuesFromModel(t *testing.T) {
type Table1Model struct {
Col1 *int
ColFloat float64
}
one := 1
toInsert := Table1Model{
Col1: &one,
ColFloat: 1.11,
}
stmt := table1.INSERT(table1Col1, table1ColFloat).
MODEL(toInsert).
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float)
VALUES ($1, $2),
($3, $4);
`
assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct {
Col1Prim int
Col2 string
}
newData := Table1Model{
Col1Prim: 1,
Col2: "one",
}
table1.
INSERT(table1Col1, table1ColFloat).
MODEL(newData)
}
func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "argument mismatch: expected struct, got []int")
}()
table2.INSERT(table2ColInt).MODEL([]int{})
}
func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1))
var expectedSQL = `
INSERT INTO db.table1 (col1) (
SELECT table1.col1 AS "table1.col1"
FROM db.table1
);
`
assertStatement(t, stmt, expectedSQL)
}
func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, $1);
`
assertStatement(t, stmt, expectedSQL, "two")
}

View file

@ -1,112 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// TableLockMode is a type of possible SQL table lock
type TableLockMode string
// LockStatement interface for SQL LOCK statement
type LockStatement interface {
Statement
IN(lockMode string) LockStatement
NOWAIT() LockStatement
}
type lockStatementImpl struct {
tables []WritableTable
lockMode string
nowait bool
}
// LOCK creates lock statement for list of tables.
func LOCK(tables ...WritableTable) LockStatement {
return &lockStatementImpl{
tables: tables,
}
}
func (l *lockStatementImpl) IN(lockMode string) LockStatement {
l.lockMode = lockMode
return l
}
func (l *lockStatementImpl) NOWAIT() LockStatement {
l.nowait = true
return l
}
func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(l, dialect...)
}
func (l *lockStatementImpl) accept(visitor visitor) {
visitor.visit(l)
for _, table := range l.tables {
table.accept(visitor)
}
}
func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
if l == nil {
return "", nil, errors.New("jet: nil Statement")
}
if len(l.tables) == 0 {
return "", nil, errors.New("jet: There is no table selected to be locked")
}
out := &SqlBuilder{
Dialect: detectDialect(l, dialect...),
}
out.NewLine()
out.WriteString("LOCK TABLE")
for i, table := range l.tables {
if i > 0 {
out.WriteString(", ")
}
err := table.serialize(LockStatementType, out)
if err != nil {
return "", nil, err
}
}
if l.lockMode != "" {
out.WriteString("IN")
out.WriteString(string(l.lockMode))
out.WriteString("MODE")
}
if l.nowait {
out.WriteString("NOWAIT")
}
query, args = out.finalize()
return
}
func (l *lockStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(l, db, destination)
}
func (l *lockStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, l, db, destination)
}
func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) {
return exec(l, db)
}
func (l *lockStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, l, db)
}

View file

@ -17,7 +17,7 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression {
//----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression {
func EXISTS(subQuery Expression) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS")
}

View file

@ -0,0 +1,48 @@
package jet
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
Serializer
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func NewSelectLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
return nil
}

View file

@ -1,355 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// SelectStatement is interface for SQL SELECT statements
type SelectStatement interface {
Statement
IExpression
DISTINCT() SelectStatement
FROM(table ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR(lock SelectLock) SelectStatement
UNION(rhs SelectStatement) SelectStatement
UNION_ALL(rhs SelectStatement) SelectStatement
INTERSECT(rhs SelectStatement) SelectStatement
INTERSECT_ALL(rhs SelectStatement) SelectStatement
EXCEPT(rhs SelectStatement) SelectStatement
EXCEPT_ALL(rhs SelectStatement) SelectStatement
AsTable(alias string) SelectTable
projections() []Projection
}
//SELECT creates new SelectStatement with list of projections
func SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection1}, projections...))
}
type selectStatementImpl struct {
ExpressionInterfaceImpl
parent SelectStatement
table ReadableTable
distinct bool
projectionList []Projection
where BoolExpression
groupBy []GroupByClause
having BoolExpression
orderBy []OrderByClause
limit, offset int64
lockFor SelectLock
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{
table: table,
projectionList: projections,
limit: -1,
offset: -1,
distinct: false,
}
newSelect.ExpressionInterfaceImpl.Parent = newSelect
newSelect.parent = newSelect
return newSelect
}
func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
s.table = table
return s.parent
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
//return newSelectTable(s.parent, alias)
panic("UNimplemented.")
}
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
s.where = expression
return s.parent
}
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.groupBy = groupByClauses
return s.parent
}
func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
s.having = expression
return s.parent
}
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
s.orderBy = clauses
return s.parent
}
func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement {
s.offset = offset
return s.parent
}
func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement {
s.limit = limit
return s.parent
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
s.distinct = true
return s.parent
}
func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement {
s.lockFor = lock
return s.parent
}
func (s *selectStatementImpl) UNION(rhs SelectStatement) SelectStatement {
return UNION(s.parent, rhs)
}
func (s *selectStatementImpl) UNION_ALL(rhs SelectStatement) SelectStatement {
return UNION_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) INTERSECT(rhs SelectStatement) SelectStatement {
return INTERSECT(s.parent, rhs)
}
func (s *selectStatementImpl) INTERSECT_ALL(rhs SelectStatement) SelectStatement {
return INTERSECT_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) EXCEPT(rhs SelectStatement) SelectStatement {
return EXCEPT(s.parent, rhs)
}
func (s *selectStatementImpl) EXCEPT_ALL(rhs SelectStatement) SelectStatement {
return EXCEPT_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) projections() []Projection {
return s.projectionList
}
func (s *selectStatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Select expression is nil. ")
}
out.WriteString("(")
out.increaseIdent()
err := s.serializeImpl(out)
out.decreaseIdent()
if err != nil {
return err
}
out.NewLine()
out.WriteString(")")
return nil
}
func (s *selectStatementImpl) serializeImpl(out *SqlBuilder) error {
if s == nil {
return errors.New("jet: Select expression is nil. ")
}
out.NewLine()
out.WriteString("SELECT")
if s.distinct {
out.WriteString("DISTINCT")
}
if len(s.projectionList) == 0 {
return errors.New("jet: no column selected for Projection")
}
err := out.writeProjections(SelectStatementType, s.projectionList)
if err != nil {
return err
}
if s.table != nil {
if err := out.writeFrom(SelectStatementType, s.table); err != nil {
return err
}
}
if s.where != nil {
err := out.writeWhere(SelectStatementType, s.where)
if err != nil {
return nil
}
}
if s.groupBy != nil && len(s.groupBy) > 0 {
err := out.writeGroupBy(SelectStatementType, s.groupBy)
if err != nil {
return err
}
}
if s.having != nil {
err := out.writeHaving(SelectStatementType, s.having)
if err != nil {
return err
}
}
if s.orderBy != nil {
err := out.writeOrderBy(SelectStatementType, s.orderBy)
if err != nil {
return err
}
}
if s.limit >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(s.limit)
}
if s.offset >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(s.offset)
}
if s.lockFor != nil {
out.NewLine()
out.WriteString("FOR")
err := s.lockFor.serialize(SelectStatementType, out)
if err != nil {
return err
}
}
return nil
}
func (s *selectStatementImpl) accept(visitor visitor) {
visitor.visit(s)
if s.table != nil {
s.table.accept(visitor)
}
if s.where != nil {
s.where.accept(visitor)
}
if s.having != nil {
s.having.accept(visitor)
}
}
func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &SqlBuilder{
Dialect: detectDialect(s, dialect...),
}
err = s.serializeImpl(queryData)
if err != nil {
return "", nil, err
}
query, args = queryData.finalize()
return
}
func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(s.parent, dialect...)
}
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(s.parent, db, destination)
}
func (s *selectStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, s.parent, db, destination)
}
func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s.parent, db)
}
func (s *selectStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, s.parent, db)
}
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
Serializer
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func NewSelectLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
return nil
}

View file

@ -1,198 +0,0 @@
package jet
import "testing"
func TestInvalidSelect(t *testing.T) {
assertStatementErr(t, SELECT(nil), "jet: Projection is nil")
}
func TestSelectColumnList(t *testing.T) {
columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt)
assertStatement(t, SELECT(columnList).FROM(table2), `
SELECT table2.col_int AS "table2.col_int",
table2.col_float AS "table2.col_float",
table3.col_int AS "table3.col_int"
FROM db.table2;
`)
}
func TestSelectLiterals(t *testing.T) {
assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), `
SELECT $1,
$2,
$3
FROM db.table1;
`, int64(1), 2.2, false)
}
func TestSelectDistinct(t *testing.T) {
assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), `
SELECT DISTINCT table1.col_bool AS "table1.col_bool"
FROM db.table1;
`)
}
func TestSelectFrom(t *testing.T) {
assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1;
`)
assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
}
func TestSelectWhere(t *testing.T) {
assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE $1;
`, true)
assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE table1.col_int >= $1;
`, int64(10))
}
func TestSelectGroupBy(t *testing.T) {
assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
GROUP BY table2.col_float;
`)
}
func TestSelectHaving(t *testing.T) {
assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.col_int AS "table3.col_int"
FROM db.table3
HAVING table1.col_bool = $1;
`, true)
}
func TestSelectOrderBy(t *testing.T) {
assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC;
`)
assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC, table2.col_int ASC;
`)
}
func TestSelectLimitOffset(t *testing.T) {
assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT $1;
`, int64(10))
assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT $1
OFFSET $2;
`, int64(10), int64(2))
}
func TestSelectSets(t *testing.T) {
select1 := SELECT(table1ColBool).FROM(table1)
select2 := SELECT(table2ColBool).FROM(table2)
assertStatement(t, select1.UNION(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.UNION_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
}

View file

@ -1,70 +1,58 @@
package jet
//// SelectTable is interface for SELECT sub-queries
//type SelectTable interface {
// ReadableTable
//
// Alias() string
//
// AllColumns() ProjectionList
//}
//
//type selectTableImpl struct {
// readableTableInterfaceImpl
// selectStmt SelectStatement
// alias string
//
// projections []Projection
//}
//
//func newSelectTable(selectStmt SelectStatement, alias string) SelectTable {
// expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
//
// expTable.readableTableInterfaceImpl.parent = expTable
//
// for _, projection := range selectStmt.projections() {
// newProjection := projection.fromImpl(expTable)
//
// expTable.projections = append(expTable.projections, newProjection)
// }
//
// return expTable
//}
//
//func (s *selectTableImpl) Alias() string {
// return s.alias
//}
//
//func (s *selectTableImpl) columns() []IColumn {
// return nil
//}
//
//func (s *selectTableImpl) accept(visitor visitor) {
// visitor.visit(s)
// s.selectStmt.accept(visitor)
//}
//
//func (s *selectTableImpl) dialect() Dialect {
// return detectDialect(s.selectStmt)
//}
//
//func (s *selectTableImpl) AllColumns() ProjectionList {
// return s.projections
//}
//
//func (s *selectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
// if s == nil {
// return errors.New("jet: Expression table is nil. ")
// }
//
// err := s.selectStmt.serialize(statement, out)
//
// if err != nil {
// return err
// }
//
// out.WriteString("AS")
// out.writeIdentifier(s.alias)
//
// return nil
//}
import "errors"
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Alias() string
AllColumns() ProjectionList
}
type SelectTableImpl2 struct {
selectStmt StatementWithProjections
alias string
projections []Projection
}
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 {
selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias}
for _, projection := range selectStmt.projections() {
newProjection := projection.fromImpl(&selectTable)
selectTable.projections = append(selectTable.projections, newProjection)
}
return selectTable
}
func (s *SelectTableImpl2) Alias() string {
return s.alias
}
func (s *SelectTableImpl2) accept(visitor visitor) {
visitor.visit(s)
s.selectStmt.accept(visitor)
}
func (s *SelectTableImpl2) AllColumns() ProjectionList {
return s.projections
}
func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Expression table is nil. ")
}
err := s.selectStmt.serialize(statement, out)
if err != nil {
return err
}
out.WriteString("AS")
out.writeIdentifier(s.alias)
return nil
}

View file

@ -1,197 +0,0 @@
package jet
import (
"errors"
)
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Union, true, toSelectList(lhs, rhs, selects...))
}
// INTERSECT returns all rows that are in query results.
// It eliminates duplicate rows from its result.
func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Intersect, false, toSelectList(lhs, rhs, selects...))
}
// INTERSECT_ALL returns all rows that are in query results.
// It does not eliminates duplicate rows from its result.
func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Intersect, true, toSelectList(lhs, rhs, selects...))
}
// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs.
// It eliminates duplicate rows from its result.
func EXCEPT(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(Except, false, toSelectList(lhs, rhs))
}
// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs.
// It does not eliminates duplicate rows from its result.
func EXCEPT_ALL(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(Except, true, toSelectList(lhs, rhs))
}
func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement {
return append([]SelectStatement{lhs, rhs}, selects...)
}
const (
Union = "UNION"
Intersect = "INTERSECT"
Except = "EXCEPT"
)
// Similar to selectStatementImpl, but less complete
type setStatementImpl struct {
selectStatementImpl
operator string
all bool
selects []SelectStatement
}
func newSetStatementImpl(operator string, all bool, selects []SelectStatement) SelectStatement {
setStatement := &setStatementImpl{
operator: operator,
all: all,
selects: selects,
}
setStatement.selectStatementImpl.ExpressionInterfaceImpl.Parent = setStatement
setStatement.selectStatementImpl.parent = setStatement
setStatement.limit = -1
setStatement.offset = -1
return setStatement
}
func (s *setStatementImpl) accept(visitor visitor) {
visitor.visit(s)
for _, selects := range s.selects {
selects.accept(visitor)
}
}
func (s *setStatementImpl) projections() []Projection {
if len(s.selects) > 0 {
return s.selects[0].projections()
}
return []Projection{}
}
func (s *setStatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Set expression is nil. ")
}
wrap := s.orderBy != nil || s.limit >= 0 || s.offset >= 0
if wrap {
out.WriteString("(")
out.increaseIdent()
}
err := s.serializeImpl(out)
if err != nil {
return err
}
if wrap {
out.decreaseIdent()
out.NewLine()
out.WriteString(")")
}
return nil
}
func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error {
if s == nil {
return errors.New("jet: Set expression is nil. ")
}
if len(s.selects) < 2 {
return errors.New("jet: UNION Statement must have at least two SELECT statements")
}
if setOverride := out.Dialect.SerializeOverride(s.operator); setOverride != nil {
return setOverride()(SelectStatementType, out)
}
out.NewLine()
out.WriteString("(")
out.increaseIdent()
for i, selectStmt := range s.selects {
out.NewLine()
if i > 0 {
out.WriteString(s.operator)
if s.all {
out.WriteString("ALL")
}
out.NewLine()
}
if selectStmt == nil {
return errors.New("jet: select statement is nil")
}
err := selectStmt.serialize(SetStatementType, out)
if err != nil {
return err
}
}
out.decreaseIdent()
out.NewLine()
out.WriteString(")")
if s.orderBy != nil {
err := out.writeOrderBy(SetStatementType, s.orderBy)
if err != nil {
return err
}
}
if s.limit >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(s.limit)
}
if s.offset >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(s.offset)
}
return nil
}
func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &SqlBuilder{
Dialect: detectDialect(s, dialect...),
}
err = s.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}

View file

@ -1,301 +0,0 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestUnionTwoSelect(t *testing.T) {
var expectedSQL = `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`
unionStmt1 := table1.
SELECT(table1Col1).
UNION(
table2.SELECT(table2Col3),
)
unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3))
assertStatement(t, unionStmt1, expectedSQL)
assertStatement(t, unionStmt2, expectedSQL)
}
func TestUnionNilSelect(t *testing.T) {
unionStmt := table1.
SELECT(table1Col1).
UNION(nil)
assertStatementErr(t, unionStmt, "jet: select statement is nil")
}
func TestUnionThreeSelect1(t *testing.T) {
unionStmt1 := table1.SELECT(table1Col1).
UNION(
table2.SELECT(table2Col3),
).
UNION(
table3.SELECT(table3Col1),
)
var expectedSQL = `
(
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
UNION
(
SELECT table3.col1 AS "table3.col1"
FROM db.table3
)
);
`
assertStatement(t, unionStmt1, expectedSQL)
}
func TestUnionThreeSelect2(t *testing.T) {
unionStmt2 := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
table3.SELECT(table3Col1),
)
var expectedSQL = `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
UNION
(
SELECT table3.col1 AS "table3.col1"
FROM db.table3
)
);
`
assertStatement(t, unionStmt2, expectedSQL)
}
func TestUnionWithOrderBy(t *testing.T) {
unionStmt := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).
ORDER_BY(table1Col1.ASC())
assertStatement(t, unionStmt, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
ORDER BY "table1.col1" ASC;
`)
}
func TestUnionWithLimitAndOffset(t *testing.T) {
query, args, err := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).
LIMIT(10).
OFFSET(11).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
LIMIT $1
OFFSET $2;
`)
assert.Equal(t, len(args), 2)
}
func TestUnionInUnion(t *testing.T) {
expectedSQL := `
(
(
SELECT table2.col3 AS "table2.col3",
table2.col3 AS "table2.col3"
FROM db.table2
)
UNION
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
);
`
query := UNION(
SELECT(table2Col3, table2Col3).FROM(table2),
UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)),
)
assertStatement(t, query, expectedSQL)
}
func TestUnionALL(t *testing.T) {
query, args, err := UNION_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestINTERSECT(t *testing.T) {
query, args, err := INTERSECT(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
INTERSECT
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestINTERSECT_ALL(t *testing.T) {
query, args, err := INTERSECT_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestEXCEPT(t *testing.T) {
query, args, err := EXCEPT(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
EXCEPT
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestEXCEPT_ALL(t *testing.T) {
query, args, err := EXCEPT_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}

View file

@ -24,7 +24,7 @@ func (s *SqlBuilder) DebugSQL() string {
const defaultIdent = 5
func (q *SqlBuilder) increaseIdent(ident ...int) {
func (q *SqlBuilder) IncreaseIdent(ident ...int) {
if len(ident) > 0 {
q.ident += ident[0]
} else {
@ -32,7 +32,7 @@ func (q *SqlBuilder) increaseIdent(ident ...int) {
}
}
func (q *SqlBuilder) decreaseIdent(ident ...int) {
func (q *SqlBuilder) DecreaseIdent(ident ...int) {
toDecrease := defaultIdent
if len(ident) > 0 {
@ -46,10 +46,10 @@ func (q *SqlBuilder) decreaseIdent(ident ...int) {
q.ident -= toDecrease
}
func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error {
q.increaseIdent()
func (q *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error {
q.IncreaseIdent()
err := SerializeProjectionList(statement, projections, q)
q.decreaseIdent()
q.DecreaseIdent()
return err
}
@ -57,9 +57,9 @@ func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error
q.NewLine()
q.WriteString("FROM")
q.increaseIdent()
q.IncreaseIdent()
err := table.serialize(statement, q)
q.decreaseIdent()
q.DecreaseIdent()
return err
}
@ -68,9 +68,9 @@ func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error
q.NewLine()
q.WriteString("WHERE")
q.increaseIdent()
q.IncreaseIdent()
err := where.serialize(statement, q, noWrap)
q.decreaseIdent()
q.DecreaseIdent()
return err
}
@ -79,9 +79,9 @@ func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClau
q.NewLine()
q.WriteString("GROUP BY")
q.increaseIdent()
q.IncreaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent()
q.DecreaseIdent()
return err
}
@ -90,9 +90,9 @@ func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClau
q.NewLine()
q.WriteString("ORDER BY")
q.increaseIdent()
q.IncreaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent()
q.DecreaseIdent()
return err
}
@ -101,9 +101,9 @@ func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) err
q.NewLine()
q.WriteString("HAVING")
q.increaseIdent()
q.IncreaseIdent()
err := having.serialize(statement, q, noWrap)
q.decreaseIdent()
q.DecreaseIdent()
return err
}
@ -115,9 +115,9 @@ func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Project
q.NewLine()
q.WriteString("RETURNING")
q.increaseIdent()
q.IncreaseIdent()
return q.writeProjections(statement, returning)
return q.WriteProjections(statement, returning)
}
func (q *SqlBuilder) NewLine() {

View file

@ -204,7 +204,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti
if !contains(options, noWrap) {
out.WriteString("(")
out.increaseIdent()
out.IncreaseIdent()
}
for _, clause := range s.Clauses {
@ -216,7 +216,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti
}
if !contains(options, noWrap) {
out.decreaseIdent()
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")
}

View file

@ -11,154 +11,36 @@ type SerializerTable interface {
}
type TableInterface interface {
Columns() []IColumn
}
type TableBase interface {
dialect() Dialect
columns() []IColumn
}
type readableTable interface {
// Generates a select query on the current tableName.
SELECT(projection Projection, projections ...Projection) SelectStatement
// Creates a inner join tableName Expression using onCondition.
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a left join tableName Expression using onCondition.
LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a right join tableName Expression using onCondition.
RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a full join tableName Expression using onCondition.
FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a cross join tableName Expression using onCondition.
CROSS_JOIN(table ReadableTable) ReadableTable
}
type writableTable interface {
INSERT(columns ...IColumn) InsertStatement
UPDATE(column IColumn, columns ...IColumn) UpdateStatement
DELETE() DeleteStatement
LOCK() LockStatement
}
// ReadableTable interface
type ReadableTable interface {
TableBase
readableTable
Serializer
acceptsVisitor
}
// WritableTable interface
type WritableTable interface {
TableBase
writableTable
Serializer
acceptsVisitor
}
// Table interface
type Table interface {
TableBase
readableTable
writableTable
Serializer
acceptsVisitor
Columns() []Column
SchemaName() string
TableName() string
AS(alias string)
}
type readableTableInterfaceImpl struct {
parent ReadableTable
}
// Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, InnerJoin, onCondition)
}
// Creates a left join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, LeftJoin, onCondition)
}
// Creates a right join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, RightJoin, onCondition)
}
func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, FullJoin, onCondition)
}
func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
return newJoinTable(r.parent, table, CrossJoin, nil)
}
type writableTableInterfaceImpl struct {
parent WritableTable
}
func (w *writableTableInterfaceImpl) INSERT(columns ...IColumn) InsertStatement {
return newInsertStatement(w.parent, UnwidColumnList(columns))
}
func (w *writableTableInterfaceImpl) UPDATE(column IColumn, columns ...IColumn) UpdateStatement {
return newUpdateStatement(w.parent, UnwindColumns(column, columns...))
}
func (w *writableTableInterfaceImpl) DELETE() DeleteStatement {
return newDeleteStatement(w.parent)
}
func (w *writableTableInterfaceImpl) LOCK() LockStatement {
return LOCK(w.parent)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table {
func NewTable(schemaName, name string, columns ...ColumnExpression) TableImpl {
t := &tableImpl{
Dialect: Dialect,
t := TableImpl{
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.SetTableName(name)
}
t.readableTableInterfaceImpl.parent = t
t.writableTableInterfaceImpl.parent = t
return t
}
type tableImpl struct {
readableTableInterfaceImpl
writableTableInterfaceImpl
Dialect Dialect
type TableImpl struct {
schemaName string
name string
alias string
columnList []Column
columnList []ColumnExpression
}
func (t *tableImpl) AS(alias string) {
func (t *TableImpl) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
@ -166,16 +48,16 @@ func (t *tableImpl) AS(alias string) {
}
}
func (t *tableImpl) SchemaName() string {
func (t *TableImpl) SchemaName() string {
return t.schemaName
}
func (t *tableImpl) TableName() string {
func (t *TableImpl) TableName() string {
return t.name
}
func (t *tableImpl) columns() []IColumn {
ret := []IColumn{}
func (t *TableImpl) Columns() []Column {
ret := []Column{}
for _, col := range t.columnList {
ret = append(ret, col)
@ -184,15 +66,11 @@ func (t *tableImpl) columns() []IColumn {
return ret
}
func (t *tableImpl) dialect() Dialect {
return t.Dialect
}
func (t *tableImpl) accept(visitor visitor) {
func (t *TableImpl) accept(visitor visitor) {
visitor.visit(t)
}
func (t *tableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if t == nil {
return errors.New("jet: tableImpl is nil. ")
}
@ -220,55 +98,45 @@ const (
)
// Join expressions are pseudo readable tables.
type joinTable struct {
readableTableInterfaceImpl
lhs ReadableTable
rhs ReadableTable
type JoinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
func newJoinTable(
lhs ReadableTable,
rhs ReadableTable,
joinType JoinType,
onCondition BoolExpression) *joinTable {
func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
joinTable := &joinTable{
joinTable := JoinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
joinTable.readableTableInterfaceImpl.parent = joinTable
return joinTable
}
func (t *joinTable) SchemaName() string {
func (t *JoinTableImpl) SchemaName() string {
return ""
}
func (t *joinTable) TableName() string {
func (t *JoinTableImpl) TableName() string {
return ""
}
func (t *joinTable) columns() []IColumn {
return append(t.lhs.columns(), t.rhs.columns()...)
func (t *JoinTableImpl) Columns() []Column {
//return append(t.lhs.columns(), t.rhs.columns()...)
panic("Unimplemented")
}
func (t *joinTable) accept(visitor visitor) {
t.lhs.accept(visitor)
t.rhs.accept(visitor)
func (t *JoinTableImpl) accept(visitor visitor) {
//t.lhs.accept(visitor)
//t.rhs.accept(visitor)
//TODO: remove
}
func (t *joinTable) dialect() Dialect {
return detectDialect(t)
}
func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
if t == nil {
return errors.New("jet: Join table is nil. ")
}
@ -318,8 +186,8 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options
return nil
}
func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn {
columnList := []IColumn{}
func UnwindColumns(column1 Column, columns ...Column) []Column {
columnList := []Column{}
if val, ok := column1.(IColumnList); ok {
for _, col := range val.Columns() {
@ -334,8 +202,8 @@ func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn {
return columnList
}
func UnwidColumnList(columns []IColumn) []IColumn {
ret := []IColumn{}
func UnwidColumnList(columns []Column) []Column {
ret := []Column{}
for _, col := range columns {
if columnList, ok := col.(IColumnList); ok {

View file

@ -1,101 +0,0 @@
package jet
import (
"testing"
)
func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}
func TestINNER_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).
INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)
INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(Int(1))).
INNER_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = $1)
INNER JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
}
func TestLEFT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)
LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(Int(1))).
LEFT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = $1)
LEFT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
}
func TestRIGHT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)
RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))).
RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = $1)
RIGHT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
}
func TestFULL_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)).
FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)
FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(Int(1))).
FULL_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = $1)
FULL JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
}
func TestCROSS_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
CROSS_JOIN(table2),
`db.table1
CROSS JOIN db.table2`)
assertClauseSerialize(t, table1.
CROSS_JOIN(table2).
CROSS_JOIN(table3),
`db.table1
CROSS JOIN db.table2
CROSS JOIN db.table3`)
}

View file

@ -2,9 +2,18 @@ package jet
import (
"gotest.tools/assert"
"strconv"
"testing"
)
var DefaultDialect = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
@ -16,21 +25,7 @@ var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1 = NewTable(
ANSII,
"db",
"table1",
table1Col1,
table1ColInt,
table1ColFloat,
table1Col3,
table1ColTime,
table1ColTimez,
table1ColBool,
table1ColDate,
table1ColTimestamp,
table1ColTimestampz,
)
var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
@ -44,46 +39,27 @@ var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable(
ANSII,
"db",
"table2",
table2Col3,
table2Col4,
table2ColInt,
table2ColFloat,
table2ColStr,
table2ColBool,
table2ColTime,
table2ColTimez,
table2ColDate,
table2ColTimestamp,
table2ColTimestampz,
)
var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2")
var table3 = NewTable(
ANSII,
"db",
"table3",
table3Col1,
table3ColInt,
table3StrCol)
var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol)
func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SqlBuilder{Dialect: ANSII}
out := SqlBuilder{Dialect: DefaultDialect}
err := clause.serialize(SelectStatementType, &out)
assert.NilError(t, err)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
out := SqlBuilder{Dialect: ANSII}
out := SqlBuilder{Dialect: DefaultDialect}
err := clause.serialize(SelectStatementType, &out)
//fmt.Println(out.buff.String())
@ -92,7 +68,7 @@ func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string)
}
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SqlBuilder{Dialect: ANSII}
out := SqlBuilder{Dialect: DefaultDialect}
err := projection.serializeForProjection(SelectStatementType, &out)
assert.NilError(t, err)
@ -110,7 +86,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect
}
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
_, _, err := stmt.Sql(ANSII)
_, _, err := stmt.Sql(DefaultDialect)
assert.Assert(t, err != nil)
assert.Error(t, err, errorStr)

View file

@ -1,5 +1,3 @@
// +build todo
package jet
import "testing"
@ -9,46 +7,46 @@ var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2)
func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
"(table1.col_timestampz = $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
"(table1.col_float < $1)", "2000-01-31 10:20:00.000 +002")
}

View file

@ -1,5 +1,3 @@
// +build TODO
package jet
import "testing"
@ -8,46 +6,46 @@ var timezVar = Timez(10, 20, 0, 0, 4)
func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1)", "10:20:00.000 +04")
}
func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1)", "10:20:00.000 +04")
}
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1)", "10:20:00.000 +04")
}
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1)", "10:20:00.000 +04")
}
func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1)", "10:20:00.000 +04")
}
func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1)", "10:20:00.000 +04")
}
func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1)", "10:20:00.000 +04")
}
func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1)", "10:20:00.000 +04")
}
func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04"))
"(table1.col_float < $1)", string("01:01:01.001 +04"))
}

View file

@ -1,126 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface {
Statement
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
}
func newUpdateStatement(table WritableTable, columns []IColumn) UpdateStatement {
return &updateStatementImpl{
table: table,
columns: columns,
values: make([]Serializer, 0, len(columns)),
}
}
type updateStatementImpl struct {
table WritableTable
columns []IColumn
values []Serializer
where BoolExpression
returning []Projection
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.values = UnwindRowFromValues(value, values)
return u
}
func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
u.values = UnwindRowFromModel(u.columns, data)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.where = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
u.returning = projections
return u
}
func (u *updateStatementImpl) accept(visitor visitor) {
visitor.visit(u)
u.table.accept(visitor)
}
func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
out := &SqlBuilder{
Dialect: detectDialect(u, dialect...),
}
out.NewLine()
out.WriteString("UPDATE")
if utils.IsNil(u.table) {
return "", nil, errors.New("jet: table to update is nil")
}
if err = u.table.serialize(UpdateStatementType, out); err != nil {
return
}
if len(u.columns) == 0 {
return "", nil, errors.New("jet: no columns selected")
}
if len(u.values) == 0 {
return "", nil, errors.New("jet: no values to updated")
}
out.NewLine()
out.WriteString("SET")
if u.where == nil {
return "", nil, errors.New("jet: WHERE clause not set")
}
if err = out.writeWhere(UpdateStatementType, u.where); err != nil {
return
}
if err = out.WriteReturning(UpdateStatementType, u.returning); err != nil {
return
}
query, args = out.finalize()
return
}
func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(u, dialect...)
}
func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(u, db, destination)
}
func (u *updateStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, u, db, destination)
}
func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(u, db)
}
func (u *updateStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, u, db)
}

View file

@ -1,57 +0,0 @@
package jet
import (
"testing"
)
func TestUpdateWithOneValue(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = $1
WHERE table1.col_int >= $2;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSQL, 1, int64(33))
}
func TestUpdateWithValues(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = $1, col_float = $2
WHERE table1.col_int >= $3;
`
stmt := table1.UPDATE(table1ColInt, table1ColFloat).
SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33))
}
func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_float = (
SELECT table1.col_float AS "table1.col_float"
FROM db.table1
)
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`
stmt := table1.
UPDATE(table1ColFloat).
SET(
table1.SELECT(table1ColFloat),
).
WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1)
assertStatement(t, stmt, expectedSQL, int64(2))
}
func TestInvalidInputs(t *testing.T) {
assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set")
assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list")
}

View file

@ -98,7 +98,7 @@ func SerializeProjectionList(statement StatementType, projections []Projection,
return nil
}
func SerializeColumnNames(columns []IColumn, out *SqlBuilder) error {
func SerializeColumnNames(columns []Column, out *SqlBuilder) error {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
@ -114,7 +114,7 @@ func SerializeColumnNames(columns []IColumn, out *SqlBuilder) error {
return nil
}
func ColumnListToProjectionList(columns []Column) []Projection {
func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection
for _, column := range columns {
@ -132,7 +132,7 @@ func valueToClause(value interface{}) Serializer {
return literal(value)
}
func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer {
func UnwindRowFromModel(columns []Column, data interface{}) []Serializer {
structValue := reflect.Indirect(reflect.ValueOf(data))
row := []Serializer{}
@ -163,7 +163,7 @@ func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer {
return row
}
func UnwindRowsFromModels(columns []IColumn, data interface{}) [][]Serializer {
func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer {
sliceValue := reflect.Indirect(reflect.ValueOf(data))
mustBe(sliceValue, reflect.Slice)

View file

@ -45,10 +45,10 @@ func (f *DialectFinder) mustGetDialect() Dialect {
func (f *DialectFinder) visit(element acceptsVisitor) {
if table, ok := element.(TableBase); ok {
dialect := table.dialect()
f.dialects[dialect.Name()] = dialect
}
//if table, ok := element.(TableBase); ok {
// dialect := table.dialect()
// f.dialects[dialect.Name()] = dialect
//}
}
func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect {