Jet internal refactor.

This commit is contained in:
go-jet 2019-08-11 14:29:03 +02:00
parent 4fbf576370
commit ee4897a1e2
49 changed files with 481 additions and 2528 deletions

View file

@ -71,19 +71,6 @@ func TestBoolLiteral(t *testing.T) {
assertClauseSerialize(t, Bool(false), "$1", false) assertClauseSerialize(t, Bool(false), "$1", false)
} }
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`(EXISTS (
SELECT $1
FROM db.table2
WHERE table1.col1 = table2.col3
))`, int64(1))
}
func TestBoolExp(t *testing.T) { func TestBoolExp(t *testing.T) {
assertClauseSerialize(t, BoolExp(String("true")), "$1", "true") assertClauseSerialize(t, BoolExp(String("true")), "$1", "true")
assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true") assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true")

View file

@ -18,12 +18,12 @@ type CastImpl struct {
expression Expression expression Expression
} }
func NewCastImpl(expression Expression) CastImpl { func NewCastImpl(expression Expression) Cast {
castImpl := CastImpl{ castImpl := CastImpl{
expression: expression, expression: expression,
} }
return castImpl return &castImpl
} }
func (b *CastImpl) AS(castType string) Expression { func (b *CastImpl) AS(castType string) Expression {

View file

@ -1,7 +1,11 @@
package jet package jet
//func TestCastAS(t *testing.T) { import (
// AssertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST(? AS boolean)", int64(1)) "testing"
// AssertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)") )
// AssertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
//} func TestCastAS(t *testing.T) {
assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1))
assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
}

View file

@ -36,7 +36,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) e
return errors.New("jet: no column selected for Projection") return errors.New("jet: no column selected for Projection")
} }
return out.writeProjections(statementType, s.Projections) return out.WriteProjections(statementType, s.Projections)
} }
type ClauseFrom struct { type ClauseFrom struct {
@ -77,9 +77,9 @@ func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder)
out.NewLine() out.NewLine()
out.WriteString("GROUP BY") out.WriteString("GROUP BY")
out.increaseIdent() out.IncreaseIdent()
err := serializeGroupByClauseList(statementType, c.List, out) err := serializeGroupByClauseList(statementType, c.List, out)
out.decreaseIdent() out.DecreaseIdent()
return err return err
} }
@ -173,15 +173,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB
wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0 wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0
//if wrap {
// out.WriteString("(")
// out.increaseIdent()
//}
if wrap { if wrap {
out.NewLine() out.NewLine()
out.WriteString("(") out.WriteString("(")
out.increaseIdent() out.IncreaseIdent()
} }
for i, selectStmt := range s.Selects { for i, selectStmt := range s.Selects {
@ -207,7 +202,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB
} }
if wrap { if wrap {
out.decreaseIdent() out.DecreaseIdent()
out.NewLine() out.NewLine()
out.WriteString(")") out.WriteString(")")
} }
@ -224,12 +219,6 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB
return err return err
} }
//if wrap {
// out.decreaseIdent()
// out.newLine()
// out.WriteString(")")
//}
return nil return nil
} }
@ -253,7 +242,7 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) e
} }
type ClauseSet struct { type ClauseSet struct {
Columns []IColumn Columns []Column
Values []Serializer Values []Serializer
} }
@ -265,7 +254,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro
return errors.New("jet: mismatch in numers of columns and values") return errors.New("jet: mismatch in numers of columns and values")
} }
out.increaseIdent(4) out.IncreaseIdent(4)
for i, column := range s.Columns { for i, column := range s.Columns {
if i > 0 { if i > 0 {
out.WriteString(", ") out.WriteString(", ")
@ -280,26 +269,26 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro
out.WriteString(" = ") out.WriteString(" = ")
if err := Serialize(s.Values[i], UpdateStatementType, out); err != nil { if err := s.Values[i].serialize(UpdateStatementType, out); err != nil {
return err return err
} }
} }
out.decreaseIdent(4) out.DecreaseIdent(4)
return nil return nil
} }
type ClauseReturning struct {
Projections []Projection
}
func (r *ClauseReturning) Serialize(statementType StatementType, out *SqlBuilder) error {
return out.WriteReturning(statementType, r.Projections)
}
type ClauseInsert struct { type ClauseInsert struct {
Table SerializerTable Table SerializerTable
Columns []IColumn Columns []Column
}
func (i *ClauseInsert) GetColumns() []Column {
if len(i.Columns) > 0 {
return i.Columns
}
return i.Table.Columns()
} }
func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error { func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error {
@ -347,7 +336,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e
out.WriteString(",") out.WriteString(",")
} }
out.increaseIdent() out.IncreaseIdent()
out.NewLine() out.NewLine()
out.WriteString("(") out.WriteString("(")
@ -358,7 +347,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e
} }
out.writeByte(')') out.writeByte(')')
out.decreaseIdent() out.DecreaseIdent()
} }
return nil return nil
} }
@ -459,235 +448,3 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) error
return nil return nil
} }
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable2(Dialect Dialect, schemaName, name string, columns ...Column) TableImpl2 {
t := TableImpl2{
Dialect: Dialect,
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.SetTableName(name)
}
return t
}
type TableImpl2 struct {
Dialect Dialect
schemaName string
name string
alias string
columnList []Column
}
func (t *TableImpl2) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
c.SetTableName(alias)
}
}
func (t *TableImpl2) SchemaName() string {
return t.schemaName
}
func (t *TableImpl2) TableName() string {
return t.name
}
func (t *TableImpl2) Columns() []IColumn {
ret := []IColumn{}
for _, col := range t.columnList {
ret = append(ret, col)
}
return ret
}
func (t *TableImpl2) dialect() Dialect {
return t.Dialect
}
func (t *TableImpl2) accept(visitor visitor) {
visitor.visit(t)
}
func (t *TableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if t == nil {
return errors.New("jet: tableImpl is nil. ")
}
out.writeIdentifier(t.schemaName)
out.WriteString(".")
out.writeIdentifier(t.name)
if len(t.alias) > 0 {
out.WriteString("AS")
out.writeIdentifier(t.alias)
}
return nil
}
// Join expressions are pseudo readable tables.
type JoinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
joinTable := JoinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
return joinTable
}
func (t *JoinTableImpl) SchemaName() string {
return ""
}
func (t *JoinTableImpl) TableName() string {
return ""
}
func (t *JoinTableImpl) Columns() []IColumn {
//return append(t.lhs.columns(), t.rhs.columns()...)
panic("Unimplemented")
}
func (t *JoinTableImpl) accept(visitor visitor) {
//t.lhs.accept(visitor)
//t.rhs.accept(visitor)
//TODO: uncoment
}
func (t *JoinTableImpl) dialect() Dialect {
return detectDialect(t)
}
func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
if t == nil {
return errors.New("jet: Join table is nil. ")
}
if utils.IsNil(t.lhs) {
return errors.New("jet: left hand side of join operation is nil table")
}
if err = t.lhs.serialize(statement, out); err != nil {
return
}
out.NewLine()
switch t.joinType {
case InnerJoin:
out.WriteString("INNER JOIN")
case LeftJoin:
out.WriteString("LEFT JOIN")
case RightJoin:
out.WriteString("RIGHT JOIN")
case FullJoin:
out.WriteString("FULL JOIN")
case CrossJoin:
out.WriteString("CROSS JOIN")
}
if utils.IsNil(t.rhs) {
return errors.New("jet: right hand side of join operation is nil table")
}
if err = t.rhs.serialize(statement, out); err != nil {
return
}
if t.onCondition == nil && t.joinType != CrossJoin {
return errors.New("jet: join condition is nil")
}
if t.onCondition != nil {
out.WriteString("ON")
if err = t.onCondition.serialize(statement, out); err != nil {
return
}
}
return nil
}
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Alias() string
AllColumns() ProjectionList
}
type SelectTableImpl2 struct {
selectStmt StatementWithProjections
alias string
projections []Projection
}
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 {
selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias}
for _, projection := range selectStmt.projections() {
newProjection := projection.fromImpl(&selectTable)
selectTable.projections = append(selectTable.projections, newProjection)
}
return selectTable
}
func (s *SelectTableImpl2) Alias() string {
return s.alias
}
func (s *SelectTableImpl2) Columns() []IColumn {
return nil
}
func (s *SelectTableImpl2) accept(visitor visitor) {
visitor.visit(s)
s.selectStmt.accept(visitor)
}
func (s *SelectTableImpl2) dialect() Dialect {
return detectDialect(s.selectStmt)
}
func (s *SelectTableImpl2) AllColumns() ProjectionList {
return s.projections
}
func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Expression table is nil. ")
}
err := s.selectStmt.serialize(statement, out)
if err != nil {
return err
}
out.WriteString("AS")
out.writeIdentifier(s.alias)
return nil
}

View file

@ -2,7 +2,7 @@
package jet package jet
type IColumn interface { type Column interface {
Name() string Name() string
TableName() string TableName() string
@ -12,9 +12,9 @@ type IColumn interface {
} }
// Column is common column interface for all types of columns. // Column is common column interface for all types of columns.
type Column interface { type ColumnExpression interface {
Column
Expression Expression
IColumn
} }
// The base type for real materialized columns. // The base type for real materialized columns.
@ -28,7 +28,7 @@ type columnImpl struct {
subQuery SelectTable subQuery SelectTable
} }
func newColumn(name string, tableName string, parent Column) columnImpl { func newColumn(name string, tableName string, parent ColumnExpression) columnImpl {
bc := columnImpl{ bc := columnImpl{
name: name, name: name,
tableName: tableName, tableName: tableName,
@ -109,19 +109,19 @@ func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options
type IColumnList interface { type IColumnList interface {
Projection Projection
IColumn Column
Columns() []Column Columns() []ColumnExpression
} }
func ColumnList(columns ...Column) IColumnList { func ColumnList(columns ...ColumnExpression) IColumnList {
return columnListImpl(columns) return columnListImpl(columns)
} }
// ColumnList is redefined type to support list of columns as single Projection // ColumnList is redefined type to support list of columns as single Projection
type columnListImpl []Column type columnListImpl []ColumnExpression
func (cl columnListImpl) Columns() []Column { func (cl columnListImpl) Columns() []ColumnExpression {
return cl return cl
} }

View file

@ -3,7 +3,7 @@ package jet
// ColumnBool is interface for SQL boolean columns. // ColumnBool is interface for SQL boolean columns.
type ColumnBool interface { type ColumnBool interface {
BoolExpression BoolExpression
IColumn Column
From(subQuery SelectTable) ColumnBool From(subQuery SelectTable) ColumnBool
} }
@ -42,7 +42,7 @@ func BoolColumn(name string) ColumnBool {
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column. // ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface { type ColumnFloat interface {
FloatExpression FloatExpression
IColumn Column
From(subQuery SelectTable) ColumnFloat From(subQuery SelectTable) ColumnFloat
} }
@ -80,7 +80,7 @@ func FloatColumn(name string) ColumnFloat {
// ColumnInteger is interface for SQL smallint, integer, bigint columns. // ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface { type ColumnInteger interface {
IntegerExpression IntegerExpression
IColumn Column
From(subQuery SelectTable) ColumnInteger From(subQuery SelectTable) ColumnInteger
} }
@ -118,7 +118,7 @@ func IntegerColumn(name string) ColumnInteger {
// bytea, uuid columns and enums types. // bytea, uuid columns and enums types.
type ColumnString interface { type ColumnString interface {
StringExpression StringExpression
IColumn Column
From(subQuery SelectTable) ColumnString From(subQuery SelectTable) ColumnString
} }
@ -155,7 +155,7 @@ func StringColumn(name string) ColumnString {
// ColumnTime is interface for SQL time column. // ColumnTime is interface for SQL time column.
type ColumnTime interface { type ColumnTime interface {
TimeExpression TimeExpression
IColumn Column
From(subQuery SelectTable) ColumnTime From(subQuery SelectTable) ColumnTime
} }
@ -190,7 +190,7 @@ func TimeColumn(name string) ColumnTime {
// ColumnTimez is interface of SQL time with time zone columns. // ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface { type ColumnTimez interface {
TimezExpression TimezExpression
IColumn Column
From(subQuery SelectTable) ColumnTimez From(subQuery SelectTable) ColumnTimez
} }
@ -227,7 +227,7 @@ func TimezColumn(name string) ColumnTimez {
// ColumnTimestamp is interface of SQL timestamp columns. // ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface { type ColumnTimestamp interface {
TimestampExpression TimestampExpression
IColumn Column
From(subQuery SelectTable) ColumnTimestamp From(subQuery SelectTable) ColumnTimestamp
} }
@ -264,7 +264,7 @@ func TimestampColumn(name string) ColumnTimestamp {
// ColumnTimestampz is interface of SQL timestamp with timezone columns. // ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface { type ColumnTimestampz interface {
TimestampzExpression TimestampzExpression
IColumn Column
From(subQuery SelectTable) ColumnTimestampz From(subQuery SelectTable) ColumnTimestampz
} }
@ -301,7 +301,7 @@ func TimestampzColumn(name string) ColumnTimestampz {
// ColumnDate is interface of SQL date columns. // ColumnDate is interface of SQL date columns.
type ColumnDate interface { type ColumnDate interface {
DateExpression DateExpression
IColumn Column
From(subQuery SelectTable) ColumnDate From(subQuery SelectTable) ColumnDate
} }

View file

@ -4,7 +4,9 @@ import (
"testing" "testing"
) )
var subQuery = table1.SELECT(table1ColFloat, table1ColInt).AsTable("sub_query") var subQuery = &SelectTableImpl2{
alias: "sub_query",
}
func TestNewBoolColumn(t *testing.T) { func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery) boolColumn := BoolColumn("colBool").From(subQuery)

View file

@ -1,110 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
RETURNING(projections ...Projection) DeleteStatement
}
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
returning []Projection
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...Projection) DeleteStatement {
d.returning = projections
return d
}
func (d *deleteStatementImpl) accept(visitor visitor) {
visitor.visit(d)
d.table.accept(visitor)
}
func (d *deleteStatementImpl) serializeImpl(out *SqlBuilder) error {
if d == nil {
return errors.New("jet: delete statement is nil")
}
out.NewLine()
out.WriteString("DELETE FROM")
if d.table == nil {
return errors.New("jet: nil tableName")
}
if err := d.table.serialize(DeleteStatementType, out); err != nil {
return err
}
if d.where == nil {
return errors.New("jet: deleting without a WHERE clause")
}
if err := out.writeWhere(DeleteStatementType, d.where); err != nil {
return err
}
if err := out.WriteReturning(DeleteStatementType, d.returning); err != nil {
return err
}
return nil
}
func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &SqlBuilder{
Dialect: detectDialect(d, dialect...),
}
err = d.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}
func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(d, dialect...)
}
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, d, db, destination)
}
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, d, db)
}

View file

@ -1,25 +0,0 @@
package jet
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementErr(t, table1.DELETE(), `jet: deleting without a WHERE clause`)
assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: deleting without a WHERE clause`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = $1;
`, int64(1))
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), `
DELETE FROM db.table1
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`, int64(1))
}

View file

@ -1,17 +1,5 @@
package jet package jet
import (
"strconv"
)
var ANSII = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
type Dialect interface { type Dialect interface {
Name() string Name() string
PackageName() string PackageName() string
@ -25,7 +13,7 @@ type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...Ser
type SerializeOverride func(expressions ...Expression) SerializeFunc type SerializeOverride func(expressions ...Expression) SerializeFunc
type QueryPlaceholderFunc func(ord int) string type QueryPlaceholderFunc func(ord int) string
type UpdateAssigmentFunc func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) type UpdateAssigmentFunc func(columns []Column, values []Serializer, out *SqlBuilder) (err error)
type DialectParams struct { type DialectParams struct {
Name string Name string

View file

@ -26,33 +26,14 @@ func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
} }
func TestIN(t *testing.T) { func TestIN(t *testing.T) {
assertClauseSerialize(t, table2ColInt.IN(Int(1), Int(2), Int(3)),
`(table2.col_int IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
} }
func TestNOT_IN(t *testing.T) { func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), assertClauseSerialize(t, table2ColInt.NOT_IN(Int(1), Int(2), Int(3)),
`($1 NOT IN (( `(table2.col_int NOT IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
} }

View file

@ -1,180 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
// Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...Projection) InsertStatement
}
func newInsertStatement(t WritableTable, columns []IColumn) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
}
}
type insertStatementImpl struct {
table WritableTable
columns []IColumn
rows [][]Serializer
query SelectStatement
returning []Projection
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.rows = append(i.rows, UnwindRowFromValues(value, values))
return i
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.rows = append(i.rows, UnwindRowFromModel(i.getColumns(), data))
return i
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.rows = append(i.rows, UnwindRowsFromModels(i.getColumns(), data)...)
return i
}
func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement {
i.returning = projections
return i
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement
return i
}
func (i *insertStatementImpl) getColumns() []IColumn {
if len(i.columns) > 0 {
return i.columns
}
return i.table.columns()
}
func (i *insertStatementImpl) accept(visitor visitor) {
visitor.visit(i)
i.table.accept(visitor)
}
func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(i, dialect...)
}
func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
out := &SqlBuilder{
Dialect: detectDialect(i, dialect...),
}
out.NewLine()
out.WriteString("INSERT INTO")
if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil")
}
err = i.table.serialize(InsertStatementType, out)
if err != nil {
return
}
if len(i.columns) > 0 {
out.WriteString("(")
err = SerializeColumnNames(i.columns, out)
if err != nil {
return
}
out.WriteString(")")
}
//TODO:
if len(i.rows) == 0 && i.query == nil {
return "", nil, errors.New("jet: no row values or query specified")
}
if len(i.rows) > 0 && i.query != nil {
return "", nil, errors.New("jet: only row values or query has to be specified")
}
if len(i.rows) > 0 {
out.WriteString("VALUES")
for rowIndex, row := range i.rows {
if rowIndex > 0 {
out.WriteString(",")
}
out.increaseIdent()
out.NewLine()
out.WriteString("(")
err = SerializeClauseList(InsertStatementType, row, out)
if err != nil {
return "", nil, err
}
out.writeByte(')')
out.decreaseIdent()
}
}
if i.query != nil {
err = i.query.serialize(InsertStatementType, out)
if err != nil {
return
}
}
if err = out.WriteReturning(InsertStatementType, i.returning); err != nil {
return
}
query, args = out.finalize()
return
}
func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db)
}
func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, i, db)
}

View file

@ -1,147 +0,0 @@
package jet
import (
"gotest.tools/assert"
"testing"
"time"
)
func TestInvalidInsert(t *testing.T) {
assertStatementErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified")
assertStatementErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list")
}
func TestInsertNilValue(t *testing.T) {
assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES
($1);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatement(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES
($1);
`, int(1))
}
func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList(table3ColInt, table3StrCol)
assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES
($1, $2);
`, 1, 3)
}
func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatement(t, table1.INSERT(table1ColTime).VALUES(date), `
INSERT INTO db.table1 (col_time) VALUES
($1);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES
($1, $2, $3);
`, 1, 2, 3)
}
func TestInsertMultipleRows(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(1, 2).
VALUES(11, 22).
VALUES(111, 222)
assertStatement(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4),
($5, $6);
`, 1, 2, 11, 22, 111, 222)
}
func TestInsertValuesFromModel(t *testing.T) {
type Table1Model struct {
Col1 *int
ColFloat float64
}
one := 1
toInsert := Table1Model{
Col1: &one,
ColFloat: 1.11,
}
stmt := table1.INSERT(table1Col1, table1ColFloat).
MODEL(toInsert).
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float)
VALUES ($1, $2),
($3, $4);
`
assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct {
Col1Prim int
Col2 string
}
newData := Table1Model{
Col1Prim: 1,
Col2: "one",
}
table1.
INSERT(table1Col1, table1ColFloat).
MODEL(newData)
}
func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "argument mismatch: expected struct, got []int")
}()
table2.INSERT(table2ColInt).MODEL([]int{})
}
func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1))
var expectedSQL = `
INSERT INTO db.table1 (col1) (
SELECT table1.col1 AS "table1.col1"
FROM db.table1
);
`
assertStatement(t, stmt, expectedSQL)
}
func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, $1);
`
assertStatement(t, stmt, expectedSQL, "two")
}

View file

@ -1,112 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// TableLockMode is a type of possible SQL table lock
type TableLockMode string
// LockStatement interface for SQL LOCK statement
type LockStatement interface {
Statement
IN(lockMode string) LockStatement
NOWAIT() LockStatement
}
type lockStatementImpl struct {
tables []WritableTable
lockMode string
nowait bool
}
// LOCK creates lock statement for list of tables.
func LOCK(tables ...WritableTable) LockStatement {
return &lockStatementImpl{
tables: tables,
}
}
func (l *lockStatementImpl) IN(lockMode string) LockStatement {
l.lockMode = lockMode
return l
}
func (l *lockStatementImpl) NOWAIT() LockStatement {
l.nowait = true
return l
}
func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(l, dialect...)
}
func (l *lockStatementImpl) accept(visitor visitor) {
visitor.visit(l)
for _, table := range l.tables {
table.accept(visitor)
}
}
func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
if l == nil {
return "", nil, errors.New("jet: nil Statement")
}
if len(l.tables) == 0 {
return "", nil, errors.New("jet: There is no table selected to be locked")
}
out := &SqlBuilder{
Dialect: detectDialect(l, dialect...),
}
out.NewLine()
out.WriteString("LOCK TABLE")
for i, table := range l.tables {
if i > 0 {
out.WriteString(", ")
}
err := table.serialize(LockStatementType, out)
if err != nil {
return "", nil, err
}
}
if l.lockMode != "" {
out.WriteString("IN")
out.WriteString(string(l.lockMode))
out.WriteString("MODE")
}
if l.nowait {
out.WriteString("NOWAIT")
}
query, args = out.finalize()
return
}
func (l *lockStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(l, db, destination)
}
func (l *lockStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, l, db, destination)
}
func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) {
return exec(l, db)
}
func (l *lockStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, l, db)
}

View file

@ -17,7 +17,7 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression {
//----------- Comparison operators ---------------// //----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery // EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression { func EXISTS(subQuery Expression) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS") return newPrefixBoolOperator(subQuery, "EXISTS")
} }

View file

@ -0,0 +1,48 @@
package jet
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
Serializer
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func NewSelectLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
return nil
}

View file

@ -1,355 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// SelectStatement is interface for SQL SELECT statements
type SelectStatement interface {
Statement
IExpression
DISTINCT() SelectStatement
FROM(table ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR(lock SelectLock) SelectStatement
UNION(rhs SelectStatement) SelectStatement
UNION_ALL(rhs SelectStatement) SelectStatement
INTERSECT(rhs SelectStatement) SelectStatement
INTERSECT_ALL(rhs SelectStatement) SelectStatement
EXCEPT(rhs SelectStatement) SelectStatement
EXCEPT_ALL(rhs SelectStatement) SelectStatement
AsTable(alias string) SelectTable
projections() []Projection
}
//SELECT creates new SelectStatement with list of projections
func SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection1}, projections...))
}
type selectStatementImpl struct {
ExpressionInterfaceImpl
parent SelectStatement
table ReadableTable
distinct bool
projectionList []Projection
where BoolExpression
groupBy []GroupByClause
having BoolExpression
orderBy []OrderByClause
limit, offset int64
lockFor SelectLock
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{
table: table,
projectionList: projections,
limit: -1,
offset: -1,
distinct: false,
}
newSelect.ExpressionInterfaceImpl.Parent = newSelect
newSelect.parent = newSelect
return newSelect
}
func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
s.table = table
return s.parent
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
//return newSelectTable(s.parent, alias)
panic("UNimplemented.")
}
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
s.where = expression
return s.parent
}
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.groupBy = groupByClauses
return s.parent
}
func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
s.having = expression
return s.parent
}
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
s.orderBy = clauses
return s.parent
}
func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement {
s.offset = offset
return s.parent
}
func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement {
s.limit = limit
return s.parent
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
s.distinct = true
return s.parent
}
func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement {
s.lockFor = lock
return s.parent
}
func (s *selectStatementImpl) UNION(rhs SelectStatement) SelectStatement {
return UNION(s.parent, rhs)
}
func (s *selectStatementImpl) UNION_ALL(rhs SelectStatement) SelectStatement {
return UNION_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) INTERSECT(rhs SelectStatement) SelectStatement {
return INTERSECT(s.parent, rhs)
}
func (s *selectStatementImpl) INTERSECT_ALL(rhs SelectStatement) SelectStatement {
return INTERSECT_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) EXCEPT(rhs SelectStatement) SelectStatement {
return EXCEPT(s.parent, rhs)
}
func (s *selectStatementImpl) EXCEPT_ALL(rhs SelectStatement) SelectStatement {
return EXCEPT_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) projections() []Projection {
return s.projectionList
}
func (s *selectStatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Select expression is nil. ")
}
out.WriteString("(")
out.increaseIdent()
err := s.serializeImpl(out)
out.decreaseIdent()
if err != nil {
return err
}
out.NewLine()
out.WriteString(")")
return nil
}
func (s *selectStatementImpl) serializeImpl(out *SqlBuilder) error {
if s == nil {
return errors.New("jet: Select expression is nil. ")
}
out.NewLine()
out.WriteString("SELECT")
if s.distinct {
out.WriteString("DISTINCT")
}
if len(s.projectionList) == 0 {
return errors.New("jet: no column selected for Projection")
}
err := out.writeProjections(SelectStatementType, s.projectionList)
if err != nil {
return err
}
if s.table != nil {
if err := out.writeFrom(SelectStatementType, s.table); err != nil {
return err
}
}
if s.where != nil {
err := out.writeWhere(SelectStatementType, s.where)
if err != nil {
return nil
}
}
if s.groupBy != nil && len(s.groupBy) > 0 {
err := out.writeGroupBy(SelectStatementType, s.groupBy)
if err != nil {
return err
}
}
if s.having != nil {
err := out.writeHaving(SelectStatementType, s.having)
if err != nil {
return err
}
}
if s.orderBy != nil {
err := out.writeOrderBy(SelectStatementType, s.orderBy)
if err != nil {
return err
}
}
if s.limit >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(s.limit)
}
if s.offset >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(s.offset)
}
if s.lockFor != nil {
out.NewLine()
out.WriteString("FOR")
err := s.lockFor.serialize(SelectStatementType, out)
if err != nil {
return err
}
}
return nil
}
func (s *selectStatementImpl) accept(visitor visitor) {
visitor.visit(s)
if s.table != nil {
s.table.accept(visitor)
}
if s.where != nil {
s.where.accept(visitor)
}
if s.having != nil {
s.having.accept(visitor)
}
}
func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &SqlBuilder{
Dialect: detectDialect(s, dialect...),
}
err = s.serializeImpl(queryData)
if err != nil {
return "", nil, err
}
query, args = queryData.finalize()
return
}
func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(s.parent, dialect...)
}
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(s.parent, db, destination)
}
func (s *selectStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, s.parent, db, destination)
}
func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s.parent, db)
}
func (s *selectStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, s.parent, db)
}
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
Serializer
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func NewSelectLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
return nil
}

View file

@ -1,198 +0,0 @@
package jet
import "testing"
func TestInvalidSelect(t *testing.T) {
assertStatementErr(t, SELECT(nil), "jet: Projection is nil")
}
func TestSelectColumnList(t *testing.T) {
columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt)
assertStatement(t, SELECT(columnList).FROM(table2), `
SELECT table2.col_int AS "table2.col_int",
table2.col_float AS "table2.col_float",
table3.col_int AS "table3.col_int"
FROM db.table2;
`)
}
func TestSelectLiterals(t *testing.T) {
assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), `
SELECT $1,
$2,
$3
FROM db.table1;
`, int64(1), 2.2, false)
}
func TestSelectDistinct(t *testing.T) {
assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), `
SELECT DISTINCT table1.col_bool AS "table1.col_bool"
FROM db.table1;
`)
}
func TestSelectFrom(t *testing.T) {
assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1;
`)
assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
}
func TestSelectWhere(t *testing.T) {
assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE $1;
`, true)
assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE table1.col_int >= $1;
`, int64(10))
}
func TestSelectGroupBy(t *testing.T) {
assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
GROUP BY table2.col_float;
`)
}
func TestSelectHaving(t *testing.T) {
assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.col_int AS "table3.col_int"
FROM db.table3
HAVING table1.col_bool = $1;
`, true)
}
func TestSelectOrderBy(t *testing.T) {
assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC;
`)
assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC, table2.col_int ASC;
`)
}
func TestSelectLimitOffset(t *testing.T) {
assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT $1;
`, int64(10))
assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT $1
OFFSET $2;
`, int64(10), int64(2))
}
func TestSelectSets(t *testing.T) {
select1 := SELECT(table1ColBool).FROM(table1)
select2 := SELECT(table2ColBool).FROM(table2)
assertStatement(t, select1.UNION(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.UNION_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
}

View file

@ -1,70 +1,58 @@
package jet package jet
//// SelectTable is interface for SELECT sub-queries import "errors"
//type SelectTable interface {
// ReadableTable // SelectTable is interface for SELECT sub-queries
// type SelectTable interface {
// Alias() string Alias() string
// AllColumns() ProjectionList
// AllColumns() ProjectionList }
//}
// type SelectTableImpl2 struct {
//type selectTableImpl struct { selectStmt StatementWithProjections
// readableTableInterfaceImpl alias string
// selectStmt SelectStatement
// alias string projections []Projection
// }
// projections []Projection
//} func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 {
// selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias}
//func newSelectTable(selectStmt SelectStatement, alias string) SelectTable {
// expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} for _, projection := range selectStmt.projections() {
// newProjection := projection.fromImpl(&selectTable)
// expTable.readableTableInterfaceImpl.parent = expTable
// selectTable.projections = append(selectTable.projections, newProjection)
// for _, projection := range selectStmt.projections() { }
// newProjection := projection.fromImpl(expTable)
// return selectTable
// expTable.projections = append(expTable.projections, newProjection) }
// }
// func (s *SelectTableImpl2) Alias() string {
// return expTable return s.alias
//} }
//
//func (s *selectTableImpl) Alias() string { func (s *SelectTableImpl2) accept(visitor visitor) {
// return s.alias visitor.visit(s)
//} s.selectStmt.accept(visitor)
// }
//func (s *selectTableImpl) columns() []IColumn {
// return nil func (s *SelectTableImpl2) AllColumns() ProjectionList {
//} return s.projections
// }
//func (s *selectTableImpl) accept(visitor visitor) {
// visitor.visit(s) func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
// s.selectStmt.accept(visitor) if s == nil {
//} return errors.New("jet: Expression table is nil. ")
// }
//func (s *selectTableImpl) dialect() Dialect {
// return detectDialect(s.selectStmt) err := s.selectStmt.serialize(statement, out)
//}
// if err != nil {
//func (s *selectTableImpl) AllColumns() ProjectionList { return err
// return s.projections }
//}
// out.WriteString("AS")
//func (s *selectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { out.writeIdentifier(s.alias)
// if s == nil {
// return errors.New("jet: Expression table is nil. ") return nil
// } }
//
// err := s.selectStmt.serialize(statement, out)
//
// if err != nil {
// return err
// }
//
// out.WriteString("AS")
// out.writeIdentifier(s.alias)
//
// return nil
//}

View file

@ -1,197 +0,0 @@
package jet
import (
"errors"
)
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Union, true, toSelectList(lhs, rhs, selects...))
}
// INTERSECT returns all rows that are in query results.
// It eliminates duplicate rows from its result.
func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Intersect, false, toSelectList(lhs, rhs, selects...))
}
// INTERSECT_ALL returns all rows that are in query results.
// It does not eliminates duplicate rows from its result.
func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(Intersect, true, toSelectList(lhs, rhs, selects...))
}
// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs.
// It eliminates duplicate rows from its result.
func EXCEPT(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(Except, false, toSelectList(lhs, rhs))
}
// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs.
// It does not eliminates duplicate rows from its result.
func EXCEPT_ALL(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(Except, true, toSelectList(lhs, rhs))
}
func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement {
return append([]SelectStatement{lhs, rhs}, selects...)
}
const (
Union = "UNION"
Intersect = "INTERSECT"
Except = "EXCEPT"
)
// Similar to selectStatementImpl, but less complete
type setStatementImpl struct {
selectStatementImpl
operator string
all bool
selects []SelectStatement
}
func newSetStatementImpl(operator string, all bool, selects []SelectStatement) SelectStatement {
setStatement := &setStatementImpl{
operator: operator,
all: all,
selects: selects,
}
setStatement.selectStatementImpl.ExpressionInterfaceImpl.Parent = setStatement
setStatement.selectStatementImpl.parent = setStatement
setStatement.limit = -1
setStatement.offset = -1
return setStatement
}
func (s *setStatementImpl) accept(visitor visitor) {
visitor.visit(s)
for _, selects := range s.selects {
selects.accept(visitor)
}
}
func (s *setStatementImpl) projections() []Projection {
if len(s.selects) > 0 {
return s.selects[0].projections()
}
return []Projection{}
}
func (s *setStatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Set expression is nil. ")
}
wrap := s.orderBy != nil || s.limit >= 0 || s.offset >= 0
if wrap {
out.WriteString("(")
out.increaseIdent()
}
err := s.serializeImpl(out)
if err != nil {
return err
}
if wrap {
out.decreaseIdent()
out.NewLine()
out.WriteString(")")
}
return nil
}
func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error {
if s == nil {
return errors.New("jet: Set expression is nil. ")
}
if len(s.selects) < 2 {
return errors.New("jet: UNION Statement must have at least two SELECT statements")
}
if setOverride := out.Dialect.SerializeOverride(s.operator); setOverride != nil {
return setOverride()(SelectStatementType, out)
}
out.NewLine()
out.WriteString("(")
out.increaseIdent()
for i, selectStmt := range s.selects {
out.NewLine()
if i > 0 {
out.WriteString(s.operator)
if s.all {
out.WriteString("ALL")
}
out.NewLine()
}
if selectStmt == nil {
return errors.New("jet: select statement is nil")
}
err := selectStmt.serialize(SetStatementType, out)
if err != nil {
return err
}
}
out.decreaseIdent()
out.NewLine()
out.WriteString(")")
if s.orderBy != nil {
err := out.writeOrderBy(SetStatementType, s.orderBy)
if err != nil {
return err
}
}
if s.limit >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(s.limit)
}
if s.offset >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(s.offset)
}
return nil
}
func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &SqlBuilder{
Dialect: detectDialect(s, dialect...),
}
err = s.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}

View file

@ -1,301 +0,0 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestUnionTwoSelect(t *testing.T) {
var expectedSQL = `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`
unionStmt1 := table1.
SELECT(table1Col1).
UNION(
table2.SELECT(table2Col3),
)
unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3))
assertStatement(t, unionStmt1, expectedSQL)
assertStatement(t, unionStmt2, expectedSQL)
}
func TestUnionNilSelect(t *testing.T) {
unionStmt := table1.
SELECT(table1Col1).
UNION(nil)
assertStatementErr(t, unionStmt, "jet: select statement is nil")
}
func TestUnionThreeSelect1(t *testing.T) {
unionStmt1 := table1.SELECT(table1Col1).
UNION(
table2.SELECT(table2Col3),
).
UNION(
table3.SELECT(table3Col1),
)
var expectedSQL = `
(
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
UNION
(
SELECT table3.col1 AS "table3.col1"
FROM db.table3
)
);
`
assertStatement(t, unionStmt1, expectedSQL)
}
func TestUnionThreeSelect2(t *testing.T) {
unionStmt2 := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
table3.SELECT(table3Col1),
)
var expectedSQL = `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
UNION
(
SELECT table3.col1 AS "table3.col1"
FROM db.table3
)
);
`
assertStatement(t, unionStmt2, expectedSQL)
}
func TestUnionWithOrderBy(t *testing.T) {
unionStmt := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).
ORDER_BY(table1Col1.ASC())
assertStatement(t, unionStmt, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
ORDER BY "table1.col1" ASC;
`)
}
func TestUnionWithLimitAndOffset(t *testing.T) {
query, args, err := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).
LIMIT(10).
OFFSET(11).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
LIMIT $1
OFFSET $2;
`)
assert.Equal(t, len(args), 2)
}
func TestUnionInUnion(t *testing.T) {
expectedSQL := `
(
(
SELECT table2.col3 AS "table2.col3",
table2.col3 AS "table2.col3"
FROM db.table2
)
UNION
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
);
`
query := UNION(
SELECT(table2Col3, table2Col3).FROM(table2),
UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)),
)
assertStatement(t, query, expectedSQL)
}
func TestUnionALL(t *testing.T) {
query, args, err := UNION_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestINTERSECT(t *testing.T) {
query, args, err := INTERSECT(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
INTERSECT
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestINTERSECT_ALL(t *testing.T) {
query, args, err := INTERSECT_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestEXCEPT(t *testing.T) {
query, args, err := EXCEPT(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
EXCEPT
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestEXCEPT_ALL(t *testing.T) {
query, args, err := EXCEPT_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}

View file

@ -24,7 +24,7 @@ func (s *SqlBuilder) DebugSQL() string {
const defaultIdent = 5 const defaultIdent = 5
func (q *SqlBuilder) increaseIdent(ident ...int) { func (q *SqlBuilder) IncreaseIdent(ident ...int) {
if len(ident) > 0 { if len(ident) > 0 {
q.ident += ident[0] q.ident += ident[0]
} else { } else {
@ -32,7 +32,7 @@ func (q *SqlBuilder) increaseIdent(ident ...int) {
} }
} }
func (q *SqlBuilder) decreaseIdent(ident ...int) { func (q *SqlBuilder) DecreaseIdent(ident ...int) {
toDecrease := defaultIdent toDecrease := defaultIdent
if len(ident) > 0 { if len(ident) > 0 {
@ -46,10 +46,10 @@ func (q *SqlBuilder) decreaseIdent(ident ...int) {
q.ident -= toDecrease q.ident -= toDecrease
} }
func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error { func (q *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error {
q.increaseIdent() q.IncreaseIdent()
err := SerializeProjectionList(statement, projections, q) err := SerializeProjectionList(statement, projections, q)
q.decreaseIdent() q.DecreaseIdent()
return err return err
} }
@ -57,9 +57,9 @@ func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error
q.NewLine() q.NewLine()
q.WriteString("FROM") q.WriteString("FROM")
q.increaseIdent() q.IncreaseIdent()
err := table.serialize(statement, q) err := table.serialize(statement, q)
q.decreaseIdent() q.DecreaseIdent()
return err return err
} }
@ -68,9 +68,9 @@ func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error
q.NewLine() q.NewLine()
q.WriteString("WHERE") q.WriteString("WHERE")
q.increaseIdent() q.IncreaseIdent()
err := where.serialize(statement, q, noWrap) err := where.serialize(statement, q, noWrap)
q.decreaseIdent() q.DecreaseIdent()
return err return err
} }
@ -79,9 +79,9 @@ func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClau
q.NewLine() q.NewLine()
q.WriteString("GROUP BY") q.WriteString("GROUP BY")
q.increaseIdent() q.IncreaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q) err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent() q.DecreaseIdent()
return err return err
} }
@ -90,9 +90,9 @@ func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClau
q.NewLine() q.NewLine()
q.WriteString("ORDER BY") q.WriteString("ORDER BY")
q.increaseIdent() q.IncreaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q) err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent() q.DecreaseIdent()
return err return err
} }
@ -101,9 +101,9 @@ func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) err
q.NewLine() q.NewLine()
q.WriteString("HAVING") q.WriteString("HAVING")
q.increaseIdent() q.IncreaseIdent()
err := having.serialize(statement, q, noWrap) err := having.serialize(statement, q, noWrap)
q.decreaseIdent() q.DecreaseIdent()
return err return err
} }
@ -115,9 +115,9 @@ func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Project
q.NewLine() q.NewLine()
q.WriteString("RETURNING") q.WriteString("RETURNING")
q.increaseIdent() q.IncreaseIdent()
return q.writeProjections(statement, returning) return q.WriteProjections(statement, returning)
} }
func (q *SqlBuilder) NewLine() { func (q *SqlBuilder) NewLine() {

View file

@ -204,7 +204,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti
if !contains(options, noWrap) { if !contains(options, noWrap) {
out.WriteString("(") out.WriteString("(")
out.increaseIdent() out.IncreaseIdent()
} }
for _, clause := range s.Clauses { for _, clause := range s.Clauses {
@ -216,7 +216,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti
} }
if !contains(options, noWrap) { if !contains(options, noWrap) {
out.decreaseIdent() out.DecreaseIdent()
out.NewLine() out.NewLine()
out.WriteString(")") out.WriteString(")")
} }

View file

@ -11,154 +11,36 @@ type SerializerTable interface {
} }
type TableInterface interface { type TableInterface interface {
Columns() []IColumn Columns() []Column
}
type TableBase interface {
dialect() Dialect
columns() []IColumn
}
type readableTable interface {
// Generates a select query on the current tableName.
SELECT(projection Projection, projections ...Projection) SelectStatement
// Creates a inner join tableName Expression using onCondition.
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a left join tableName Expression using onCondition.
LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a right join tableName Expression using onCondition.
RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a full join tableName Expression using onCondition.
FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a cross join tableName Expression using onCondition.
CROSS_JOIN(table ReadableTable) ReadableTable
}
type writableTable interface {
INSERT(columns ...IColumn) InsertStatement
UPDATE(column IColumn, columns ...IColumn) UpdateStatement
DELETE() DeleteStatement
LOCK() LockStatement
}
// ReadableTable interface
type ReadableTable interface {
TableBase
readableTable
Serializer
acceptsVisitor
}
// WritableTable interface
type WritableTable interface {
TableBase
writableTable
Serializer
acceptsVisitor
}
// Table interface
type Table interface {
TableBase
readableTable
writableTable
Serializer
acceptsVisitor
SchemaName() string SchemaName() string
TableName() string TableName() string
AS(alias string) AS(alias string)
} }
type readableTableInterfaceImpl struct {
parent ReadableTable
}
// Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, InnerJoin, onCondition)
}
// Creates a left join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, LeftJoin, onCondition)
}
// Creates a right join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, RightJoin, onCondition)
}
func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, FullJoin, onCondition)
}
func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
return newJoinTable(r.parent, table, CrossJoin, nil)
}
type writableTableInterfaceImpl struct {
parent WritableTable
}
func (w *writableTableInterfaceImpl) INSERT(columns ...IColumn) InsertStatement {
return newInsertStatement(w.parent, UnwidColumnList(columns))
}
func (w *writableTableInterfaceImpl) UPDATE(column IColumn, columns ...IColumn) UpdateStatement {
return newUpdateStatement(w.parent, UnwindColumns(column, columns...))
}
func (w *writableTableInterfaceImpl) DELETE() DeleteStatement {
return newDeleteStatement(w.parent)
}
func (w *writableTableInterfaceImpl) LOCK() LockStatement {
return LOCK(w.parent)
}
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table { func NewTable(schemaName, name string, columns ...ColumnExpression) TableImpl {
t := &tableImpl{ t := TableImpl{
Dialect: Dialect,
schemaName: schemaName, schemaName: schemaName,
name: name, name: name,
columnList: columns, columnList: columns,
} }
for _, c := range columns { for _, c := range columns {
c.SetTableName(name) c.SetTableName(name)
} }
t.readableTableInterfaceImpl.parent = t
t.writableTableInterfaceImpl.parent = t
return t return t
} }
type tableImpl struct { type TableImpl struct {
readableTableInterfaceImpl
writableTableInterfaceImpl
Dialect Dialect
schemaName string schemaName string
name string name string
alias string alias string
columnList []Column columnList []ColumnExpression
} }
func (t *tableImpl) AS(alias string) { func (t *TableImpl) AS(alias string) {
t.alias = alias t.alias = alias
for _, c := range t.columnList { for _, c := range t.columnList {
@ -166,16 +48,16 @@ func (t *tableImpl) AS(alias string) {
} }
} }
func (t *tableImpl) SchemaName() string { func (t *TableImpl) SchemaName() string {
return t.schemaName return t.schemaName
} }
func (t *tableImpl) TableName() string { func (t *TableImpl) TableName() string {
return t.name return t.name
} }
func (t *tableImpl) columns() []IColumn { func (t *TableImpl) Columns() []Column {
ret := []IColumn{} ret := []Column{}
for _, col := range t.columnList { for _, col := range t.columnList {
ret = append(ret, col) ret = append(ret, col)
@ -184,15 +66,11 @@ func (t *tableImpl) columns() []IColumn {
return ret return ret
} }
func (t *tableImpl) dialect() Dialect { func (t *TableImpl) accept(visitor visitor) {
return t.Dialect
}
func (t *tableImpl) accept(visitor visitor) {
visitor.visit(t) visitor.visit(t)
} }
func (t *tableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if t == nil { if t == nil {
return errors.New("jet: tableImpl is nil. ") return errors.New("jet: tableImpl is nil. ")
} }
@ -220,55 +98,45 @@ const (
) )
// Join expressions are pseudo readable tables. // Join expressions are pseudo readable tables.
type joinTable struct { type JoinTableImpl struct {
readableTableInterfaceImpl lhs Serializer
rhs Serializer
lhs ReadableTable
rhs ReadableTable
joinType JoinType joinType JoinType
onCondition BoolExpression onCondition BoolExpression
} }
func newJoinTable( func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
lhs ReadableTable,
rhs ReadableTable,
joinType JoinType,
onCondition BoolExpression) *joinTable {
joinTable := &joinTable{ joinTable := JoinTableImpl{
lhs: lhs, lhs: lhs,
rhs: rhs, rhs: rhs,
joinType: joinType, joinType: joinType,
onCondition: onCondition, onCondition: onCondition,
} }
joinTable.readableTableInterfaceImpl.parent = joinTable
return joinTable return joinTable
} }
func (t *joinTable) SchemaName() string { func (t *JoinTableImpl) SchemaName() string {
return "" return ""
} }
func (t *joinTable) TableName() string { func (t *JoinTableImpl) TableName() string {
return "" return ""
} }
func (t *joinTable) columns() []IColumn { func (t *JoinTableImpl) Columns() []Column {
return append(t.lhs.columns(), t.rhs.columns()...) //return append(t.lhs.columns(), t.rhs.columns()...)
panic("Unimplemented")
} }
func (t *joinTable) accept(visitor visitor) { func (t *JoinTableImpl) accept(visitor visitor) {
t.lhs.accept(visitor) //t.lhs.accept(visitor)
t.rhs.accept(visitor) //t.rhs.accept(visitor)
//TODO: remove
} }
func (t *joinTable) dialect() Dialect { func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
return detectDialect(t)
}
func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
if t == nil { if t == nil {
return errors.New("jet: Join table is nil. ") return errors.New("jet: Join table is nil. ")
} }
@ -318,8 +186,8 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options
return nil return nil
} }
func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn { func UnwindColumns(column1 Column, columns ...Column) []Column {
columnList := []IColumn{} columnList := []Column{}
if val, ok := column1.(IColumnList); ok { if val, ok := column1.(IColumnList); ok {
for _, col := range val.Columns() { for _, col := range val.Columns() {
@ -334,8 +202,8 @@ func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn {
return columnList return columnList
} }
func UnwidColumnList(columns []IColumn) []IColumn { func UnwidColumnList(columns []Column) []Column {
ret := []IColumn{} ret := []Column{}
for _, col := range columns { for _, col := range columns {
if columnList, ok := col.(IColumnList); ok { if columnList, ok := col.(IColumnList); ok {

View file

@ -2,9 +2,18 @@ package jet
import ( import (
"gotest.tools/assert" "gotest.tools/assert"
"strconv"
"testing" "testing"
) )
var DefaultDialect = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
var table1Col1 = IntegerColumn("col1") var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int") var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float") var table1ColFloat = FloatColumn("col_float")
@ -16,21 +25,7 @@ var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool") var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date") var table1ColDate = DateColumn("col_date")
var table1 = NewTable( var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz)
ANSII,
"db",
"table1",
table1Col1,
table1ColInt,
table1ColFloat,
table1Col3,
table1ColTime,
table1ColTimez,
table1ColBool,
table1ColDate,
table1ColTimestamp,
table1ColTimestampz,
)
var table2Col3 = IntegerColumn("col3") var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4") var table2Col4 = IntegerColumn("col4")
@ -44,46 +39,27 @@ var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date") var table2ColDate = DateColumn("col_date")
var table2 = NewTable( var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz)
ANSII,
"db",
"table2",
table2Col3,
table2Col4,
table2ColInt,
table2ColFloat,
table2ColStr,
table2ColBool,
table2ColTime,
table2ColTimez,
table2ColDate,
table2ColTimestamp,
table2ColTimestampz,
)
var table3Col1 = IntegerColumn("col1") var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int") var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2") var table3StrCol = StringColumn("col2")
var table3 = NewTable( var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol)
ANSII,
"db",
"table3",
table3Col1,
table3ColInt,
table3StrCol)
func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SqlBuilder{Dialect: ANSII} out := SqlBuilder{Dialect: DefaultDialect}
err := clause.serialize(SelectStatementType, &out) err := clause.serialize(SelectStatementType, &out)
assert.NilError(t, err) assert.NilError(t, err)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query) assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args) assert.DeepEqual(t, out.Args, args)
} }
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
out := SqlBuilder{Dialect: ANSII} out := SqlBuilder{Dialect: DefaultDialect}
err := clause.serialize(SelectStatementType, &out) err := clause.serialize(SelectStatementType, &out)
//fmt.Println(out.buff.String()) //fmt.Println(out.buff.String())
@ -92,7 +68,7 @@ func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string)
} }
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SqlBuilder{Dialect: ANSII} out := SqlBuilder{Dialect: DefaultDialect}
err := projection.serializeForProjection(SelectStatementType, &out) err := projection.serializeForProjection(SelectStatementType, &out)
assert.NilError(t, err) assert.NilError(t, err)
@ -110,7 +86,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect
} }
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
_, _, err := stmt.Sql(ANSII) _, _, err := stmt.Sql(DefaultDialect)
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
assert.Error(t, err, errorStr) assert.Error(t, err, errorStr)

View file

@ -1,5 +1,3 @@
// +build todo
package jet package jet
import "testing" import "testing"
@ -9,46 +7,46 @@ var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2)
func TestTimestampzExpressionEQ(t *testing.T) { func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz), assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") "(table1.col_timestampz = $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionNOT_EQ(t *testing.T) { func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionLT(t *testing.T) { func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionLT_EQ(t *testing.T) { func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionGT(t *testing.T) { func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExpressionGT_EQ(t *testing.T) { func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1)", "2000-01-31 10:20:00.000 +002")
} }
func TestTimestampzExp(t *testing.T) { func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") "(table1.col_float < $1)", "2000-01-31 10:20:00.000 +002")
} }

View file

@ -1,5 +1,3 @@
// +build TODO
package jet package jet
import "testing" import "testing"
@ -8,46 +6,46 @@ var timezVar = Timez(10, 20, 0, 0, 4)
func TestTimezExpressionEQ(t *testing.T) { func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)") assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionNOT_EQ(t *testing.T) { func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)") assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)") assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)") assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionLT(t *testing.T) { func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)") assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionLT_EQ(t *testing.T) { func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)") assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionGT(t *testing.T) { func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)") assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1)", "10:20:00.000 +04")
} }
func TestTimezExpressionGT_EQ(t *testing.T) { func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)") assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1)", "10:20:00.000 +04")
} }
func TestTimezExp(t *testing.T) { func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)), assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04")) "(table1.col_float < $1)", string("01:01:01.001 +04"))
} }

View file

@ -1,126 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface {
Statement
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
}
func newUpdateStatement(table WritableTable, columns []IColumn) UpdateStatement {
return &updateStatementImpl{
table: table,
columns: columns,
values: make([]Serializer, 0, len(columns)),
}
}
type updateStatementImpl struct {
table WritableTable
columns []IColumn
values []Serializer
where BoolExpression
returning []Projection
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.values = UnwindRowFromValues(value, values)
return u
}
func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
u.values = UnwindRowFromModel(u.columns, data)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.where = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
u.returning = projections
return u
}
func (u *updateStatementImpl) accept(visitor visitor) {
visitor.visit(u)
u.table.accept(visitor)
}
func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
out := &SqlBuilder{
Dialect: detectDialect(u, dialect...),
}
out.NewLine()
out.WriteString("UPDATE")
if utils.IsNil(u.table) {
return "", nil, errors.New("jet: table to update is nil")
}
if err = u.table.serialize(UpdateStatementType, out); err != nil {
return
}
if len(u.columns) == 0 {
return "", nil, errors.New("jet: no columns selected")
}
if len(u.values) == 0 {
return "", nil, errors.New("jet: no values to updated")
}
out.NewLine()
out.WriteString("SET")
if u.where == nil {
return "", nil, errors.New("jet: WHERE clause not set")
}
if err = out.writeWhere(UpdateStatementType, u.where); err != nil {
return
}
if err = out.WriteReturning(UpdateStatementType, u.returning); err != nil {
return
}
query, args = out.finalize()
return
}
func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(u, dialect...)
}
func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(u, db, destination)
}
func (u *updateStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, u, db, destination)
}
func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(u, db)
}
func (u *updateStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, u, db)
}

View file

@ -1,57 +0,0 @@
package jet
import (
"testing"
)
func TestUpdateWithOneValue(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = $1
WHERE table1.col_int >= $2;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSQL, 1, int64(33))
}
func TestUpdateWithValues(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = $1, col_float = $2
WHERE table1.col_int >= $3;
`
stmt := table1.UPDATE(table1ColInt, table1ColFloat).
SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33))
}
func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_float = (
SELECT table1.col_float AS "table1.col_float"
FROM db.table1
)
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`
stmt := table1.
UPDATE(table1ColFloat).
SET(
table1.SELECT(table1ColFloat),
).
WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1)
assertStatement(t, stmt, expectedSQL, int64(2))
}
func TestInvalidInputs(t *testing.T) {
assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set")
assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list")
}

View file

@ -98,7 +98,7 @@ func SerializeProjectionList(statement StatementType, projections []Projection,
return nil return nil
} }
func SerializeColumnNames(columns []IColumn, out *SqlBuilder) error { func SerializeColumnNames(columns []Column, out *SqlBuilder) error {
for i, col := range columns { for i, col := range columns {
if i > 0 { if i > 0 {
out.WriteString(", ") out.WriteString(", ")
@ -114,7 +114,7 @@ func SerializeColumnNames(columns []IColumn, out *SqlBuilder) error {
return nil return nil
} }
func ColumnListToProjectionList(columns []Column) []Projection { func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection var ret []Projection
for _, column := range columns { for _, column := range columns {
@ -132,7 +132,7 @@ func valueToClause(value interface{}) Serializer {
return literal(value) return literal(value)
} }
func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer { func UnwindRowFromModel(columns []Column, data interface{}) []Serializer {
structValue := reflect.Indirect(reflect.ValueOf(data)) structValue := reflect.Indirect(reflect.ValueOf(data))
row := []Serializer{} row := []Serializer{}
@ -163,7 +163,7 @@ func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer {
return row return row
} }
func UnwindRowsFromModels(columns []IColumn, data interface{}) [][]Serializer { func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer {
sliceValue := reflect.Indirect(reflect.ValueOf(data)) sliceValue := reflect.Indirect(reflect.ValueOf(data))
mustBe(sliceValue, reflect.Slice) mustBe(sliceValue, reflect.Slice)

View file

@ -45,10 +45,10 @@ func (f *DialectFinder) mustGetDialect() Dialect {
func (f *DialectFinder) visit(element acceptsVisitor) { func (f *DialectFinder) visit(element acceptsVisitor) {
if table, ok := element.(TableBase); ok { //if table, ok := element.(TableBase); ok {
dialect := table.dialect() // dialect := table.dialect()
f.dialects[dialect.Name()] = dialect // f.dialects[dialect.Name()] = dialect
} //}
} }
func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect { func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect {

View file

@ -14,13 +14,13 @@ type cast interface {
} }
type castImpl struct { type castImpl struct {
jet.CastImpl jet.Cast
} }
func CAST(expr jet.Expression) cast { func CAST(expr jet.Expression) cast {
castImpl := &castImpl{} castImpl := &castImpl{}
castImpl.CastImpl = jet.NewCastImpl(expr) castImpl.Cast = jet.NewCastImpl(expr)
return castImpl return castImpl
} }

View file

@ -2,7 +2,7 @@ package mysql
import "github.com/go-jet/jet/internal/jet" import "github.com/go-jet/jet/internal/jet"
type Column jet.Column type Column jet.ColumnExpression
type IColumnList jet.IColumnList type IColumnList jet.IColumnList

View file

@ -33,3 +33,16 @@ func TestIntExpressionBIT_XOR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11))
} }
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`(EXISTS (
SELECT ?
FROM db.table2
WHERE table1.col1 = table2.col3
))`, int64(1))
}

View file

@ -16,9 +16,9 @@ type InsertStatement interface {
QUERY(selectStatement SelectStatement) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement
} }
func newInsertStatement(table Table, columns []jet.IColumn) InsertStatement { func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{} newInsert := &insertStatementImpl{}
newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newInsert, newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.Values, &newInsert.Select) &newInsert.Insert, &newInsert.Values, &newInsert.Select)
newInsert.Insert.Table = table newInsert.Insert.Table = table
@ -41,12 +41,12 @@ func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) I
} }
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.getColumns(), data)) i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data))
return i return i
} }
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.getColumns(), data)...) i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...)
return i return i
} }
@ -54,11 +54,3 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
i.Select.Query = selectStatement i.Select.Query = selectStatement
return i return i
} }
func (i *insertStatementImpl) getColumns() []jet.IColumn {
if len(i.Insert.Columns) > 0 {
return i.Insert.Columns
}
return i.Insert.Table.Columns()
}

View file

@ -12,12 +12,12 @@ type Table interface {
jet.SerializerTable jet.SerializerTable
readableTable readableTable
INSERT(columns ...jet.IColumn) InsertStatement INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement DELETE() DeleteStatement
//LOCK() LockStatement //LOCK() LockStatement
AS(alias string) //As(alias string)
} }
type readableTable interface { type readableTable interface {
@ -41,8 +41,8 @@ type readableTable interface {
} }
type ReadableTable interface { type ReadableTable interface {
jet.SerializerTable
readableTable readableTable
jet.Serializer
} }
type readableTableInterfaceImpl struct { type readableTableInterfaceImpl struct {
@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) Table {
return newJoinTable(r.parent, table, jet.CrossJoin, nil) return newJoinTable(r.parent, table, jet.CrossJoin, nil)
} }
func NewTable(schemaName, name string, columns ...jet.Column) Table { func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{ t := &tableImpl{
TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), TableImpl: jet.NewTable(schemaName, name, columns...),
} }
t.readableTableInterfaceImpl.parent = t t.readableTableInterfaceImpl.parent = t
@ -89,16 +89,16 @@ func NewTable(schemaName, name string, columns ...jet.Column) Table {
} }
type tableImpl struct { type tableImpl struct {
jet.TableImpl2 jet.TableImpl
readableTableInterfaceImpl readableTableInterfaceImpl
parent Table parent Table
} }
func (w *tableImpl) INSERT(columns ...jet.IColumn) InsertStatement { func (w *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) return newInsertStatement(w.parent, jet.UnwidColumnList(columns))
} }
func (w *tableImpl) UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement { func (w *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement {
return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...))
} }

101
mysql/table_test.go Normal file
View file

@ -0,0 +1,101 @@
package mysql
import (
"testing"
)
func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}
func TestINNER_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).
INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)
INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(Int(1))).
INNER_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = ?)
INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestLEFT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)
LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(Int(1))).
LEFT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = ?)
LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestRIGHT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)
RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))).
RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = ?)
RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestFULL_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)).
FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)
FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(Int(1))).
FULL_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = ?)
FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestCROSS_JOIN(t *testing.T) {
assertClauseSerialize(t, table1.
CROSS_JOIN(table2),
`db.table1
CROSS JOIN db.table2`)
assertClauseSerialize(t, table1.
CROSS_JOIN(table2).
CROSS_JOIN(table3),
`db.table1
CROSS JOIN db.table2
CROSS JOIN db.table3`)
}

View file

@ -20,7 +20,7 @@ type updateStatementImpl struct {
Where jet.ClauseWhere Where jet.ClauseWhere
} }
func newUpdateStatement(table Table, columns []jet.IColumn) UpdateStatement { func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{} update := &updateStatementImpl{}
update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update,
&update.Set, &update.Where) &update.Set, &update.Where)

View file

@ -64,6 +64,8 @@ func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, ar
assert.NilError(t, err) assert.NilError(t, err)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query) assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args) assert.DeepEqual(t, out.Args, args)
} }

View file

@ -37,13 +37,13 @@ type cast interface {
} }
type castImpl struct { type castImpl struct {
jet.CastImpl jet.Cast
} }
func CAST(expr Expression) cast { func CAST(expr Expression) cast {
castImpl := &castImpl{} castImpl := &castImpl{}
castImpl.CastImpl = jet.NewCastImpl(expr) castImpl.Cast = jet.NewCastImpl(expr)
return castImpl return castImpl
} }

21
postgres/clauses.go Normal file
View file

@ -0,0 +1,21 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type ClauseReturning struct {
Projections []jet.Projection
}
func (r *ClauseReturning) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) error {
if len(r.Projections) == 0 {
return nil
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
return out.WriteProjections(statementType, r.Projections)
}

View file

@ -2,7 +2,7 @@ package postgres
import "github.com/go-jet/jet/internal/jet" import "github.com/go-jet/jet/internal/jet"
type Column jet.Column type Column jet.ColumnExpression
type IColumnList jet.IColumnList type IColumnList jet.IColumnList

View file

@ -15,7 +15,7 @@ type deleteStatementImpl struct {
Delete jet.ClauseStatementBegin Delete jet.ClauseStatementBegin
Where jet.ClauseWhere Where jet.ClauseWhere
Returning jet.ClauseReturning Returning ClauseReturning
} }
func newDeleteStatement(table WritableTable) DeleteStatement { func newDeleteStatement(table WritableTable) DeleteStatement {

View file

@ -13,3 +13,48 @@ func TestString_REGEXP_LIKE_function(t *testing.T) {
assertClauseSerialize(t, REGEXP_LIKE(table3StrCol, String("JOHN"), "c"), "table3.col2 ~ $1", "JOHN") assertClauseSerialize(t, REGEXP_LIKE(table3StrCol, String("JOHN"), "c"), "table3.col2 ~ $1", "JOHN")
assertClauseSerialize(t, REGEXP_LIKE(table3StrCol, String("JOHN"), "i"), "table3.col2 ~* $1", "JOHN") assertClauseSerialize(t, REGEXP_LIKE(table3StrCol, String("JOHN"), "i"), "table3.col2 ~* $1", "JOHN")
} }
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`(EXISTS (
SELECT $1
FROM db.table2
WHERE table1.col1 = table2.col3
))`, int64(1))
}
func TestIN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
}
func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
`($1 NOT IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
}

View file

@ -19,9 +19,9 @@ type InsertStatement interface {
RETURNING(projections ...jet.Projection) InsertStatement RETURNING(projections ...jet.Projection) InsertStatement
} }
func newInsertStatement(table WritableTable, columns []jet.IColumn) InsertStatement { func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{} newInsert := &insertStatementImpl{}
newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newInsert, newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.Values, &newInsert.Select, &newInsert.Returning) &newInsert.Insert, &newInsert.Values, &newInsert.Select, &newInsert.Returning)
newInsert.Insert.Table = table newInsert.Insert.Table = table
@ -36,7 +36,7 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert Insert jet.ClauseInsert
Values jet.ClauseValues Values jet.ClauseValues
Select jet.ClauseQuery Select jet.ClauseQuery
Returning jet.ClauseReturning Returning ClauseReturning
} }
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
@ -45,12 +45,12 @@ func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) I
} }
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.getColumns(), data)) i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data))
return i return i
} }
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.getColumns(), data)...) i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...)
return i return i
} }
@ -63,11 +63,3 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
i.Select.Query = selectStatement i.Select.Query = selectStatement
return i return i
} }
func (i *insertStatementImpl) getColumns() []jet.IColumn {
if len(i.Insert.Columns) > 0 {
return i.Insert.Columns
}
return i.Insert.Table.Columns()
}

View file

@ -25,7 +25,7 @@ type LockStatement interface {
func LOCK(tables ...jet.SerializerTable) LockStatement { func LOCK(tables ...jet.SerializerTable) LockStatement {
newLock := &lockStatementImpl{} newLock := &lockStatementImpl{}
newLock.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newLock, newLock.StatementImpl = jet.NewStatementImpl(Dialect, jet.LockStatementType, newLock,
&newLock.StatementBegin, &newLock.In, &newLock.NoWait) &newLock.StatementBegin, &newLock.In, &newLock.NoWait)
newLock.StatementBegin.Name = "LOCK TABLE" newLock.StatementBegin.Name = "LOCK TABLE"

View file

@ -23,8 +23,8 @@ type readableTable interface {
} }
type writableTable interface { type writableTable interface {
INSERT(columns ...jet.IColumn) InsertStatement INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement DELETE() DeleteStatement
LOCK() LockStatement LOCK() LockStatement
} }
@ -47,12 +47,12 @@ type Table interface {
//table //table
readableTable readableTable
writableTable writableTable
jet.Serializer jet.SerializerTable
//acceptsVisitor //acceptsVisitor
SchemaName() string //SchemaName() string
TableName() string //TableName() string
AS(alias string) //As(alias string)
} }
type readableTableInterfaceImpl struct { type readableTableInterfaceImpl struct {
@ -91,11 +91,11 @@ type writableTableInterfaceImpl struct {
parent WritableTable parent WritableTable
} }
func (w *writableTableInterfaceImpl) INSERT(columns ...jet.IColumn) InsertStatement { func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) return newInsertStatement(w.parent, jet.UnwidColumnList(columns))
} }
func (w *writableTableInterfaceImpl) UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement { func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement {
return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...))
} }
@ -111,13 +111,13 @@ type table2Impl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
writableTableInterfaceImpl writableTableInterfaceImpl
jet.TableImpl2 jet.TableImpl
} }
func NewTable(schemaName, name string, columns ...jet.Column) Table { func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &table2Impl{ t := &table2Impl{
TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), TableImpl: jet.NewTable(schemaName, name, columns...),
} }
t.readableTableInterfaceImpl.parent = t t.readableTableInterfaceImpl.parent = t

View file

@ -1,4 +1,4 @@
package jet package postgres
import ( import (
"testing" "testing"

View file

@ -22,10 +22,10 @@ type updateStatementImpl struct {
Update jet.ClauseUpdate Update jet.ClauseUpdate
Set ClauseSet Set ClauseSet
Where jet.ClauseWhere Where jet.ClauseWhere
Returning jet.ClauseReturning Returning ClauseReturning
} }
func newUpdateStatement(table WritableTable, columns []jet.IColumn) UpdateStatement { func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{} update := &updateStatementImpl{}
update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update,
&update.Set, &update.Where, &update.Returning) &update.Set, &update.Where, &update.Returning)
@ -58,7 +58,7 @@ func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateSta
} }
type ClauseSet struct { type ClauseSet struct {
Columns []jet.IColumn Columns []jet.Column
Values []jet.Serializer Values []jet.Serializer
} }