diff --git a/internal/jet/bool_expression_test.go b/internal/jet/bool_expression_test.go index 995d302..5295374 100644 --- a/internal/jet/bool_expression_test.go +++ b/internal/jet/bool_expression_test.go @@ -71,19 +71,6 @@ func TestBoolLiteral(t *testing.T) { assertClauseSerialize(t, Bool(false), "$1", false) } -func TestExists(t *testing.T) { - assertClauseSerialize(t, EXISTS( - table2. - SELECT(Int(1)). - WHERE(table1Col1.EQ(table2Col3)), - ), - `(EXISTS ( - SELECT $1 - FROM db.table2 - WHERE table1.col1 = table2.col3 -))`, int64(1)) -} - func TestBoolExp(t *testing.T) { assertClauseSerialize(t, BoolExp(String("true")), "$1", "true") assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true") diff --git a/internal/jet/cast.go b/internal/jet/cast.go index 886257e..36b3ec8 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -18,12 +18,12 @@ type CastImpl struct { expression Expression } -func NewCastImpl(expression Expression) CastImpl { +func NewCastImpl(expression Expression) Cast { castImpl := CastImpl{ expression: expression, } - return castImpl + return &castImpl } func (b *CastImpl) AS(castType string) Expression { diff --git a/internal/jet/cast_test.go b/internal/jet/cast_test.go index 996a5f8..b72cede 100644 --- a/internal/jet/cast_test.go +++ b/internal/jet/cast_test.go @@ -1,7 +1,11 @@ package jet -//func TestCastAS(t *testing.T) { -// AssertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST(? AS boolean)", int64(1)) -// AssertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)") -// AssertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)") -//} +import ( + "testing" +) + +func TestCastAS(t *testing.T) { + assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1)) + assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)") + assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)") +} diff --git a/internal/jet/clause.go b/internal/jet/clause.go index b92861d..23ba600 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -36,7 +36,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) e return errors.New("jet: no column selected for Projection") } - return out.writeProjections(statementType, s.Projections) + return out.WriteProjections(statementType, s.Projections) } type ClauseFrom struct { @@ -77,9 +77,9 @@ func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) out.NewLine() out.WriteString("GROUP BY") - out.increaseIdent() + out.IncreaseIdent() err := serializeGroupByClauseList(statementType, c.List, out) - out.decreaseIdent() + out.DecreaseIdent() return err } @@ -173,15 +173,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0 - //if wrap { - // out.WriteString("(") - // out.increaseIdent() - //} - if wrap { out.NewLine() out.WriteString("(") - out.increaseIdent() + out.IncreaseIdent() } for i, selectStmt := range s.Selects { @@ -207,7 +202,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB } if wrap { - out.decreaseIdent() + out.DecreaseIdent() out.NewLine() out.WriteString(")") } @@ -224,12 +219,6 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB return err } - //if wrap { - // out.decreaseIdent() - // out.newLine() - // out.WriteString(")") - //} - return nil } @@ -253,7 +242,7 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) e } type ClauseSet struct { - Columns []IColumn + Columns []Column Values []Serializer } @@ -265,7 +254,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro return errors.New("jet: mismatch in numers of columns and values") } - out.increaseIdent(4) + out.IncreaseIdent(4) for i, column := range s.Columns { if i > 0 { out.WriteString(", ") @@ -280,26 +269,26 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro out.WriteString(" = ") - if err := Serialize(s.Values[i], UpdateStatementType, out); err != nil { + if err := s.Values[i].serialize(UpdateStatementType, out); err != nil { return err } } - out.decreaseIdent(4) + out.DecreaseIdent(4) return nil } -type ClauseReturning struct { - Projections []Projection -} - -func (r *ClauseReturning) Serialize(statementType StatementType, out *SqlBuilder) error { - return out.WriteReturning(statementType, r.Projections) -} - type ClauseInsert struct { Table SerializerTable - Columns []IColumn + Columns []Column +} + +func (i *ClauseInsert) GetColumns() []Column { + if len(i.Columns) > 0 { + return i.Columns + } + + return i.Table.Columns() } func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error { @@ -347,7 +336,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e out.WriteString(",") } - out.increaseIdent() + out.IncreaseIdent() out.NewLine() out.WriteString("(") @@ -358,7 +347,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e } out.writeByte(')') - out.decreaseIdent() + out.DecreaseIdent() } return nil } @@ -459,235 +448,3 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) error return nil } - -// NewTable creates new table with schema Name, table Name and list of columns -func NewTable2(Dialect Dialect, schemaName, name string, columns ...Column) TableImpl2 { - - t := TableImpl2{ - Dialect: Dialect, - schemaName: schemaName, - name: name, - columnList: columns, - } - - for _, c := range columns { - c.SetTableName(name) - } - - return t -} - -type TableImpl2 struct { - Dialect Dialect - schemaName string - name string - alias string - columnList []Column -} - -func (t *TableImpl2) AS(alias string) { - t.alias = alias - - for _, c := range t.columnList { - c.SetTableName(alias) - } -} - -func (t *TableImpl2) SchemaName() string { - return t.schemaName -} - -func (t *TableImpl2) TableName() string { - return t.name -} - -func (t *TableImpl2) Columns() []IColumn { - ret := []IColumn{} - - for _, col := range t.columnList { - ret = append(ret, col) - } - - return ret -} - -func (t *TableImpl2) dialect() Dialect { - return t.Dialect -} - -func (t *TableImpl2) accept(visitor visitor) { - visitor.visit(t) -} - -func (t *TableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if t == nil { - return errors.New("jet: tableImpl is nil. ") - } - - out.writeIdentifier(t.schemaName) - out.WriteString(".") - out.writeIdentifier(t.name) - - if len(t.alias) > 0 { - out.WriteString("AS") - out.writeIdentifier(t.alias) - } - - return nil -} - -// Join expressions are pseudo readable tables. -type JoinTableImpl struct { - lhs Serializer - rhs Serializer - joinType JoinType - onCondition BoolExpression -} - -func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl { - - joinTable := JoinTableImpl{ - lhs: lhs, - rhs: rhs, - joinType: joinType, - onCondition: onCondition, - } - - return joinTable -} - -func (t *JoinTableImpl) SchemaName() string { - return "" -} - -func (t *JoinTableImpl) TableName() string { - return "" -} - -func (t *JoinTableImpl) Columns() []IColumn { - //return append(t.lhs.columns(), t.rhs.columns()...) - panic("Unimplemented") -} - -func (t *JoinTableImpl) accept(visitor visitor) { - //t.lhs.accept(visitor) - //t.rhs.accept(visitor) - //TODO: uncoment -} - -func (t *JoinTableImpl) dialect() Dialect { - return detectDialect(t) -} - -func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) { - if t == nil { - return errors.New("jet: Join table is nil. ") - } - - if utils.IsNil(t.lhs) { - return errors.New("jet: left hand side of join operation is nil table") - } - - if err = t.lhs.serialize(statement, out); err != nil { - return - } - - out.NewLine() - - switch t.joinType { - case InnerJoin: - out.WriteString("INNER JOIN") - case LeftJoin: - out.WriteString("LEFT JOIN") - case RightJoin: - out.WriteString("RIGHT JOIN") - case FullJoin: - out.WriteString("FULL JOIN") - case CrossJoin: - out.WriteString("CROSS JOIN") - } - - if utils.IsNil(t.rhs) { - return errors.New("jet: right hand side of join operation is nil table") - } - - if err = t.rhs.serialize(statement, out); err != nil { - return - } - - if t.onCondition == nil && t.joinType != CrossJoin { - return errors.New("jet: join condition is nil") - } - - if t.onCondition != nil { - out.WriteString("ON") - if err = t.onCondition.serialize(statement, out); err != nil { - return - } - } - - return nil -} - -// SelectTable is interface for SELECT sub-queries -type SelectTable interface { - Alias() string - AllColumns() ProjectionList -} - -type SelectTableImpl2 struct { - selectStmt StatementWithProjections - alias string - - projections []Projection -} - -func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 { - selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias} - - for _, projection := range selectStmt.projections() { - newProjection := projection.fromImpl(&selectTable) - - selectTable.projections = append(selectTable.projections, newProjection) - } - - return selectTable -} - -func (s *SelectTableImpl2) Alias() string { - return s.alias -} - -func (s *SelectTableImpl2) Columns() []IColumn { - return nil -} - -func (s *SelectTableImpl2) accept(visitor visitor) { - visitor.visit(s) - s.selectStmt.accept(visitor) -} - -func (s *SelectTableImpl2) dialect() Dialect { - return detectDialect(s.selectStmt) -} - -func (s *SelectTableImpl2) AllColumns() ProjectionList { - return s.projections -} - -func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if s == nil { - return errors.New("jet: Expression table is nil. ") - } - - err := s.selectStmt.serialize(statement, out) - - if err != nil { - return err - } - - out.WriteString("AS") - out.writeIdentifier(s.alias) - - return nil -} diff --git a/internal/jet/column.go b/internal/jet/column.go index 0783d5a..22cf4f3 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -2,7 +2,7 @@ package jet -type IColumn interface { +type Column interface { Name() string TableName() string @@ -12,9 +12,9 @@ type IColumn interface { } // Column is common column interface for all types of columns. -type Column interface { +type ColumnExpression interface { + Column Expression - IColumn } // The base type for real materialized columns. @@ -28,7 +28,7 @@ type columnImpl struct { subQuery SelectTable } -func newColumn(name string, tableName string, parent Column) columnImpl { +func newColumn(name string, tableName string, parent ColumnExpression) columnImpl { bc := columnImpl{ name: name, tableName: tableName, @@ -109,19 +109,19 @@ func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options type IColumnList interface { Projection - IColumn + Column - Columns() []Column + Columns() []ColumnExpression } -func ColumnList(columns ...Column) IColumnList { +func ColumnList(columns ...ColumnExpression) IColumnList { return columnListImpl(columns) } // ColumnList is redefined type to support list of columns as single Projection -type columnListImpl []Column +type columnListImpl []ColumnExpression -func (cl columnListImpl) Columns() []Column { +func (cl columnListImpl) Columns() []ColumnExpression { return cl } diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index baff46c..c82a5a1 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -3,7 +3,7 @@ package jet // ColumnBool is interface for SQL boolean columns. type ColumnBool interface { BoolExpression - IColumn + Column From(subQuery SelectTable) ColumnBool } @@ -42,7 +42,7 @@ func BoolColumn(name string) ColumnBool { // ColumnFloat is interface for SQL real, numeric, decimal or double precision column. type ColumnFloat interface { FloatExpression - IColumn + Column From(subQuery SelectTable) ColumnFloat } @@ -80,7 +80,7 @@ func FloatColumn(name string) ColumnFloat { // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger interface { IntegerExpression - IColumn + Column From(subQuery SelectTable) ColumnInteger } @@ -118,7 +118,7 @@ func IntegerColumn(name string) ColumnInteger { // bytea, uuid columns and enums types. type ColumnString interface { StringExpression - IColumn + Column From(subQuery SelectTable) ColumnString } @@ -155,7 +155,7 @@ func StringColumn(name string) ColumnString { // ColumnTime is interface for SQL time column. type ColumnTime interface { TimeExpression - IColumn + Column From(subQuery SelectTable) ColumnTime } @@ -190,7 +190,7 @@ func TimeColumn(name string) ColumnTime { // ColumnTimez is interface of SQL time with time zone columns. type ColumnTimez interface { TimezExpression - IColumn + Column From(subQuery SelectTable) ColumnTimez } @@ -227,7 +227,7 @@ func TimezColumn(name string) ColumnTimez { // ColumnTimestamp is interface of SQL timestamp columns. type ColumnTimestamp interface { TimestampExpression - IColumn + Column From(subQuery SelectTable) ColumnTimestamp } @@ -264,7 +264,7 @@ func TimestampColumn(name string) ColumnTimestamp { // ColumnTimestampz is interface of SQL timestamp with timezone columns. type ColumnTimestampz interface { TimestampzExpression - IColumn + Column From(subQuery SelectTable) ColumnTimestampz } @@ -301,7 +301,7 @@ func TimestampzColumn(name string) ColumnTimestampz { // ColumnDate is interface of SQL date columns. type ColumnDate interface { DateExpression - IColumn + Column From(subQuery SelectTable) ColumnDate } diff --git a/internal/jet/column_types_test.go b/internal/jet/column_types_test.go index df8c241..2e12ef1 100644 --- a/internal/jet/column_types_test.go +++ b/internal/jet/column_types_test.go @@ -4,7 +4,9 @@ import ( "testing" ) -var subQuery = table1.SELECT(table1ColFloat, table1ColInt).AsTable("sub_query") +var subQuery = &SelectTableImpl2{ + alias: "sub_query", +} func TestNewBoolColumn(t *testing.T) { boolColumn := BoolColumn("colBool").From(subQuery) diff --git a/internal/jet/delete_statement.go b/internal/jet/delete_statement.go deleted file mode 100644 index fe6f448..0000000 --- a/internal/jet/delete_statement.go +++ /dev/null @@ -1,110 +0,0 @@ -package jet - -import ( - "context" - "database/sql" - "errors" - "github.com/go-jet/jet/execution" -) - -// DeleteStatement is interface for SQL DELETE statement -type DeleteStatement interface { - Statement - - WHERE(expression BoolExpression) DeleteStatement - - RETURNING(projections ...Projection) DeleteStatement -} - -func newDeleteStatement(table WritableTable) DeleteStatement { - return &deleteStatementImpl{ - table: table, - } -} - -type deleteStatementImpl struct { - table WritableTable - where BoolExpression - returning []Projection -} - -func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { - d.where = expression - return d -} - -func (d *deleteStatementImpl) RETURNING(projections ...Projection) DeleteStatement { - d.returning = projections - return d -} - -func (d *deleteStatementImpl) accept(visitor visitor) { - visitor.visit(d) - - d.table.accept(visitor) -} - -func (d *deleteStatementImpl) serializeImpl(out *SqlBuilder) error { - if d == nil { - return errors.New("jet: delete statement is nil") - } - out.NewLine() - out.WriteString("DELETE FROM") - - if d.table == nil { - return errors.New("jet: nil tableName") - } - - if err := d.table.serialize(DeleteStatementType, out); err != nil { - return err - } - - if d.where == nil { - return errors.New("jet: deleting without a WHERE clause") - } - - if err := out.writeWhere(DeleteStatementType, d.where); err != nil { - return err - } - - if err := out.WriteReturning(DeleteStatementType, d.returning); err != nil { - return err - } - - return nil -} - -func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { - queryData := &SqlBuilder{ - Dialect: detectDialect(d, dialect...), - } - - err = d.serializeImpl(queryData) - - if err != nil { - return - } - - query, args = queryData.finalize() - return -} - -func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { - return debugSql(d, dialect...) -} - -func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(d, db, destination) -} - -func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - return queryContext(context, d, db, destination) -} - -func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(d, db) -} - -func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - return execContext(context, d, db) -} diff --git a/internal/jet/delete_statement_test.go b/internal/jet/delete_statement_test.go deleted file mode 100644 index 6de033a..0000000 --- a/internal/jet/delete_statement_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package jet - -import ( - "testing" -) - -func TestDeleteUnconditionally(t *testing.T) { - assertStatementErr(t, table1.DELETE(), `jet: deleting without a WHERE clause`) - assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: deleting without a WHERE clause`) -} - -func TestDeleteWithWhere(t *testing.T) { - assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` -DELETE FROM db.table1 -WHERE table1.col1 = $1; -`, int64(1)) -} - -func TestDeleteWithWhereAndReturning(t *testing.T) { - assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), ` -DELETE FROM db.table1 -WHERE table1.col1 = $1 -RETURNING table1.col1 AS "table1.col1"; -`, int64(1)) -} diff --git a/internal/jet/dialects.go b/internal/jet/dialect.go similarity index 85% rename from internal/jet/dialects.go rename to internal/jet/dialect.go index 912ced1..7fc1111 100644 --- a/internal/jet/dialects.go +++ b/internal/jet/dialect.go @@ -1,17 +1,5 @@ package jet -import ( - "strconv" -) - -var ANSII = NewDialect(DialectParams{ // just for tests - AliasQuoteChar: '"', - IdentifierQuoteChar: '"', - ArgumentPlaceholder: func(ord int) string { - return "$" + strconv.Itoa(ord) - }, -}) - type Dialect interface { Name() string PackageName() string @@ -25,7 +13,7 @@ type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...Ser type SerializeOverride func(expressions ...Expression) SerializeFunc type QueryPlaceholderFunc func(ord int) string -type UpdateAssigmentFunc func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) +type UpdateAssigmentFunc func(columns []Column, values []Serializer, out *SqlBuilder) (err error) type DialectParams struct { Name string diff --git a/internal/jet/expression_test.go b/internal/jet/expression_test.go index da0e033..8db4cf9 100644 --- a/internal/jet/expression_test.go +++ b/internal/jet/expression_test.go @@ -26,33 +26,14 @@ func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { } func TestIN(t *testing.T) { + assertClauseSerialize(t, table2ColInt.IN(Int(1), Int(2), Int(3)), + `(table2.col_int IN ($1, $2, $3))`, int64(1), int64(2), int64(3)) - assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), - `($1 IN (( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 -)))`, float64(1.11)) - - assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) IN (( - SELECT table2.col3 AS "table2.col3", - table3.col1 AS "table3.col1" - FROM db.table2 -)))`, int64(12)) } func TestNOT_IN(t *testing.T) { - assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), - `($1 NOT IN (( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 -)))`, float64(1.11)) + assertClauseSerialize(t, table2ColInt.NOT_IN(Int(1), Int(2), Int(3)), + `(table2.col_int NOT IN ($1, $2, $3))`, int64(1), int64(2), int64(3)) - assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) NOT IN (( - SELECT table2.col3 AS "table2.col3", - table3.col1 AS "table3.col1" - FROM db.table2 -)))`, int64(12)) } diff --git a/internal/jet/insert_statement.go b/internal/jet/insert_statement.go deleted file mode 100644 index 3c299be..0000000 --- a/internal/jet/insert_statement.go +++ /dev/null @@ -1,180 +0,0 @@ -package jet - -import ( - "context" - "database/sql" - "errors" - "github.com/go-jet/jet/execution" - "github.com/go-jet/jet/internal/utils" -) - -// InsertStatement is interface for SQL INSERT statements -type InsertStatement interface { - Statement - - // Insert row of values - VALUES(value interface{}, values ...interface{}) InsertStatement - // Insert row of values, where value for each column is extracted from filed of structure data. - // If data is not struct or there is no field for every column selected, this method will panic. - MODEL(data interface{}) InsertStatement - - MODELS(data interface{}) InsertStatement - - QUERY(selectStatement SelectStatement) InsertStatement - - RETURNING(projections ...Projection) InsertStatement -} - -func newInsertStatement(t WritableTable, columns []IColumn) InsertStatement { - return &insertStatementImpl{ - table: t, - columns: columns, - } -} - -type insertStatementImpl struct { - table WritableTable - columns []IColumn - rows [][]Serializer - query SelectStatement - returning []Projection -} - -func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { - i.rows = append(i.rows, UnwindRowFromValues(value, values)) - return i -} - -func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - i.rows = append(i.rows, UnwindRowFromModel(i.getColumns(), data)) - return i -} - -func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { - i.rows = append(i.rows, UnwindRowsFromModels(i.getColumns(), data)...) - return i -} - -func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement { - i.returning = projections - return i -} - -func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { - i.query = selectStatement - return i -} - -func (i *insertStatementImpl) getColumns() []IColumn { - if len(i.columns) > 0 { - return i.columns - } - - return i.table.columns() -} - -func (i *insertStatementImpl) accept(visitor visitor) { - visitor.visit(i) - - i.table.accept(visitor) -} - -func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { - return debugSql(i, dialect...) -} - -func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { - out := &SqlBuilder{ - Dialect: detectDialect(i, dialect...), - } - - out.NewLine() - out.WriteString("INSERT INTO") - - if utils.IsNil(i.table) { - return "", nil, errors.New("jet: table is nil") - } - - err = i.table.serialize(InsertStatementType, out) - - if err != nil { - return - } - - if len(i.columns) > 0 { - out.WriteString("(") - - err = SerializeColumnNames(i.columns, out) - - if err != nil { - return - } - - out.WriteString(")") - } - - //TODO: - - if len(i.rows) == 0 && i.query == nil { - return "", nil, errors.New("jet: no row values or query specified") - } - - if len(i.rows) > 0 && i.query != nil { - return "", nil, errors.New("jet: only row values or query has to be specified") - } - - if len(i.rows) > 0 { - out.WriteString("VALUES") - - for rowIndex, row := range i.rows { - if rowIndex > 0 { - out.WriteString(",") - } - - out.increaseIdent() - out.NewLine() - out.WriteString("(") - - err = SerializeClauseList(InsertStatementType, row, out) - - if err != nil { - return "", nil, err - } - - out.writeByte(')') - out.decreaseIdent() - } - } - - if i.query != nil { - err = i.query.serialize(InsertStatementType, out) - - if err != nil { - return - } - } - - if err = out.WriteReturning(InsertStatementType, i.returning); err != nil { - return - } - - query, args = out.finalize() - - return -} - -func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(i, db, destination) -} - -func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - return queryContext(context, i, db, destination) -} - -func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(i, db) -} - -func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - return execContext(context, i, db) -} diff --git a/internal/jet/insert_statement_test.go b/internal/jet/insert_statement_test.go deleted file mode 100644 index 95679f5..0000000 --- a/internal/jet/insert_statement_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package jet - -import ( - "gotest.tools/assert" - "testing" - "time" -) - -func TestInvalidInsert(t *testing.T) { - assertStatementErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") - assertStatementErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") -} - -func TestInsertNilValue(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), ` -INSERT INTO db.table1 (col1) VALUES - ($1); -`, nil) -} - -func TestInsertSingleValue(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1).VALUES(1), ` -INSERT INTO db.table1 (col1) VALUES - ($1); -`, int(1)) -} - -func TestInsertWithColumnList(t *testing.T) { - columnList := ColumnList(table3ColInt, table3StrCol) - - assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), ` -INSERT INTO db.table3 (col_int, col2) VALUES - ($1, $2); -`, 1, 3) -} - -func TestInsertDate(t *testing.T) { - date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) - - assertStatement(t, table1.INSERT(table1ColTime).VALUES(date), ` -INSERT INTO db.table1 (col_time) VALUES - ($1); -`, date) -} - -func TestInsertMultipleValues(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` -INSERT INTO db.table1 (col1, col_float, col3) VALUES - ($1, $2, $3); -`, 1, 2, 3) -} - -func TestInsertMultipleRows(t *testing.T) { - stmt := table1.INSERT(table1Col1, table1ColFloat). - VALUES(1, 2). - VALUES(11, 22). - VALUES(111, 222) - - assertStatement(t, stmt, ` -INSERT INTO db.table1 (col1, col_float) VALUES - ($1, $2), - ($3, $4), - ($5, $6); -`, 1, 2, 11, 22, 111, 222) -} - -func TestInsertValuesFromModel(t *testing.T) { - type Table1Model struct { - Col1 *int - ColFloat float64 - } - - one := 1 - - toInsert := Table1Model{ - Col1: &one, - ColFloat: 1.11, - } - - stmt := table1.INSERT(table1Col1, table1ColFloat). - MODEL(toInsert). - MODEL(&toInsert) - - expectedSQL := ` -INSERT INTO db.table1 (col1, col_float) -VALUES ($1, $2), - ($3, $4); -` - - assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) -} - -func TestInsertValuesFromModelColumnMismatch(t *testing.T) { - defer func() { - r := recover() - assert.Equal(t, r, "missing struct field for column : col1") - }() - type Table1Model struct { - Col1Prim int - Col2 string - } - - newData := Table1Model{ - Col1Prim: 1, - Col2: "one", - } - - table1. - INSERT(table1Col1, table1ColFloat). - MODEL(newData) -} - -func TestInsertFromNonStructModel(t *testing.T) { - - defer func() { - r := recover() - assert.Equal(t, r, "argument mismatch: expected struct, got []int") - }() - - table2.INSERT(table2ColInt).MODEL([]int{}) -} - -func TestInsertQuery(t *testing.T) { - - stmt := table1.INSERT(table1Col1). - QUERY(table1.SELECT(table1Col1)) - - var expectedSQL = ` -INSERT INTO db.table1 (col1) ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 -); -` - assertStatement(t, stmt, expectedSQL) -} - -func TestInsertDefaultValue(t *testing.T) { - stmt := table1.INSERT(table1Col1, table1ColFloat). - VALUES(DEFAULT, "two") - - var expectedSQL = ` -INSERT INTO db.table1 (col1, col_float) VALUES - (DEFAULT, $1); -` - - assertStatement(t, stmt, expectedSQL, "two") -} diff --git a/internal/jet/lock_statement.go b/internal/jet/lock_statement.go deleted file mode 100644 index f49f2bc..0000000 --- a/internal/jet/lock_statement.go +++ /dev/null @@ -1,112 +0,0 @@ -package jet - -import ( - "context" - "database/sql" - "errors" - "github.com/go-jet/jet/execution" -) - -// TableLockMode is a type of possible SQL table lock -type TableLockMode string - -// LockStatement interface for SQL LOCK statement -type LockStatement interface { - Statement - - IN(lockMode string) LockStatement - NOWAIT() LockStatement -} - -type lockStatementImpl struct { - tables []WritableTable - lockMode string - nowait bool -} - -// LOCK creates lock statement for list of tables. -func LOCK(tables ...WritableTable) LockStatement { - return &lockStatementImpl{ - tables: tables, - } -} - -func (l *lockStatementImpl) IN(lockMode string) LockStatement { - l.lockMode = lockMode - return l -} - -func (l *lockStatementImpl) NOWAIT() LockStatement { - l.nowait = true - return l -} - -func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { - return debugSql(l, dialect...) -} - -func (l *lockStatementImpl) accept(visitor visitor) { - visitor.visit(l) - - for _, table := range l.tables { - table.accept(visitor) - } -} - -func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { - if l == nil { - return "", nil, errors.New("jet: nil Statement") - } - - if len(l.tables) == 0 { - return "", nil, errors.New("jet: There is no table selected to be locked") - } - - out := &SqlBuilder{ - Dialect: detectDialect(l, dialect...), - } - - out.NewLine() - out.WriteString("LOCK TABLE") - - for i, table := range l.tables { - if i > 0 { - out.WriteString(", ") - } - - err := table.serialize(LockStatementType, out) - - if err != nil { - return "", nil, err - } - } - - if l.lockMode != "" { - out.WriteString("IN") - out.WriteString(string(l.lockMode)) - out.WriteString("MODE") - } - - if l.nowait { - out.WriteString("NOWAIT") - } - - query, args = out.finalize() - return -} - -func (l *lockStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(l, db, destination) -} - -func (l *lockStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - return queryContext(context, l, db, destination) -} - -func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) { - return exec(l, db) -} - -func (l *lockStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - return execContext(context, l, db) -} diff --git a/internal/jet/operators.go b/internal/jet/operators.go index ff0b301..ab837f8 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -17,7 +17,7 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression { //----------- Comparison operators ---------------// // EXISTS checks for existence of the rows in subQuery -func EXISTS(subQuery SelectStatement) BoolExpression { +func EXISTS(subQuery Expression) BoolExpression { return newPrefixBoolOperator(subQuery, "EXISTS") } diff --git a/internal/jet/select_lock.go b/internal/jet/select_lock.go new file mode 100644 index 0000000..8e79b9c --- /dev/null +++ b/internal/jet/select_lock.go @@ -0,0 +1,48 @@ +package jet + +// SelectLock is interface for SELECT statement locks +type SelectLock interface { + Serializer + + NOWAIT() SelectLock + SKIP_LOCKED() SelectLock +} + +type selectLockImpl struct { + lockStrength string + noWait, skipLocked bool +} + +func NewSelectLock(name string) func() SelectLock { + return func() SelectLock { + return newSelectLock(name) + } +} + +func newSelectLock(lockStrength string) SelectLock { + return &selectLockImpl{lockStrength: lockStrength} +} + +func (s *selectLockImpl) NOWAIT() SelectLock { + s.noWait = true + return s +} + +func (s *selectLockImpl) SKIP_LOCKED() SelectLock { + s.skipLocked = true + return s +} + +func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { + out.WriteString(s.lockStrength) + + if s.noWait { + out.WriteString("NOWAIT") + } + + if s.skipLocked { + out.WriteString("SKIP LOCKED") + } + + return nil +} diff --git a/internal/jet/select_statement.go b/internal/jet/select_statement.go deleted file mode 100644 index 77e2d9a..0000000 --- a/internal/jet/select_statement.go +++ /dev/null @@ -1,355 +0,0 @@ -package jet - -import ( - "context" - "database/sql" - "errors" - "github.com/go-jet/jet/execution" -) - -// SelectStatement is interface for SQL SELECT statements -type SelectStatement interface { - Statement - IExpression - - DISTINCT() SelectStatement - FROM(table ReadableTable) SelectStatement - WHERE(expression BoolExpression) SelectStatement - GROUP_BY(groupByClauses ...GroupByClause) SelectStatement - HAVING(boolExpression BoolExpression) SelectStatement - ORDER_BY(orderByClauses ...OrderByClause) SelectStatement - LIMIT(limit int64) SelectStatement - OFFSET(offset int64) SelectStatement - FOR(lock SelectLock) SelectStatement - - UNION(rhs SelectStatement) SelectStatement - UNION_ALL(rhs SelectStatement) SelectStatement - INTERSECT(rhs SelectStatement) SelectStatement - INTERSECT_ALL(rhs SelectStatement) SelectStatement - EXCEPT(rhs SelectStatement) SelectStatement - EXCEPT_ALL(rhs SelectStatement) SelectStatement - - AsTable(alias string) SelectTable - - projections() []Projection -} - -//SELECT creates new SelectStatement with list of projections -func SELECT(projection1 Projection, projections ...Projection) SelectStatement { - return newSelectStatement(nil, append([]Projection{projection1}, projections...)) -} - -type selectStatementImpl struct { - ExpressionInterfaceImpl - parent SelectStatement - - table ReadableTable - distinct bool - projectionList []Projection - where BoolExpression - groupBy []GroupByClause - having BoolExpression - - orderBy []OrderByClause - limit, offset int64 - - lockFor SelectLock -} - -func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { - newSelect := &selectStatementImpl{ - table: table, - projectionList: projections, - limit: -1, - offset: -1, - distinct: false, - } - - newSelect.ExpressionInterfaceImpl.Parent = newSelect - newSelect.parent = newSelect - - return newSelect -} - -func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { - s.table = table - return s.parent -} - -func (s *selectStatementImpl) AsTable(alias string) SelectTable { - //return newSelectTable(s.parent, alias) - panic("UNimplemented.") -} - -func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { - s.where = expression - return s.parent -} - -func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { - s.groupBy = groupByClauses - return s.parent -} - -func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement { - s.having = expression - return s.parent -} - -func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { - s.orderBy = clauses - return s.parent -} - -func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { - s.offset = offset - return s.parent -} - -func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { - s.limit = limit - return s.parent -} - -func (s *selectStatementImpl) DISTINCT() SelectStatement { - s.distinct = true - return s.parent -} - -func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { - s.lockFor = lock - return s.parent -} - -func (s *selectStatementImpl) UNION(rhs SelectStatement) SelectStatement { - return UNION(s.parent, rhs) -} - -func (s *selectStatementImpl) UNION_ALL(rhs SelectStatement) SelectStatement { - return UNION_ALL(s.parent, rhs) -} - -func (s *selectStatementImpl) INTERSECT(rhs SelectStatement) SelectStatement { - return INTERSECT(s.parent, rhs) -} - -func (s *selectStatementImpl) INTERSECT_ALL(rhs SelectStatement) SelectStatement { - return INTERSECT_ALL(s.parent, rhs) -} - -func (s *selectStatementImpl) EXCEPT(rhs SelectStatement) SelectStatement { - return EXCEPT(s.parent, rhs) -} - -func (s *selectStatementImpl) EXCEPT_ALL(rhs SelectStatement) SelectStatement { - return EXCEPT_ALL(s.parent, rhs) -} - -func (s *selectStatementImpl) projections() []Projection { - return s.projectionList -} - -func (s *selectStatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if s == nil { - return errors.New("jet: Select expression is nil. ") - } - out.WriteString("(") - - out.increaseIdent() - err := s.serializeImpl(out) - out.decreaseIdent() - - if err != nil { - return err - } - - out.NewLine() - out.WriteString(")") - - return nil -} - -func (s *selectStatementImpl) serializeImpl(out *SqlBuilder) error { - if s == nil { - return errors.New("jet: Select expression is nil. ") - } - - out.NewLine() - out.WriteString("SELECT") - - if s.distinct { - out.WriteString("DISTINCT") - } - - if len(s.projectionList) == 0 { - return errors.New("jet: no column selected for Projection") - } - - err := out.writeProjections(SelectStatementType, s.projectionList) - - if err != nil { - return err - } - - if s.table != nil { - if err := out.writeFrom(SelectStatementType, s.table); err != nil { - return err - } - } - - if s.where != nil { - err := out.writeWhere(SelectStatementType, s.where) - - if err != nil { - return nil - } - } - - if s.groupBy != nil && len(s.groupBy) > 0 { - err := out.writeGroupBy(SelectStatementType, s.groupBy) - - if err != nil { - return err - } - } - - if s.having != nil { - err := out.writeHaving(SelectStatementType, s.having) - - if err != nil { - return err - } - } - - if s.orderBy != nil { - err := out.writeOrderBy(SelectStatementType, s.orderBy) - - if err != nil { - return err - } - } - - if s.limit >= 0 { - out.NewLine() - out.WriteString("LIMIT") - out.insertParametrizedArgument(s.limit) - } - - if s.offset >= 0 { - out.NewLine() - out.WriteString("OFFSET") - out.insertParametrizedArgument(s.offset) - } - - if s.lockFor != nil { - out.NewLine() - out.WriteString("FOR") - err := s.lockFor.serialize(SelectStatementType, out) - - if err != nil { - return err - } - } - - return nil -} - -func (s *selectStatementImpl) accept(visitor visitor) { - visitor.visit(s) - - if s.table != nil { - s.table.accept(visitor) - } - - if s.where != nil { - s.where.accept(visitor) - } - - if s.having != nil { - s.having.accept(visitor) - } -} - -func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { - - queryData := &SqlBuilder{ - Dialect: detectDialect(s, dialect...), - } - - err = s.serializeImpl(queryData) - - if err != nil { - return "", nil, err - } - - query, args = queryData.finalize() - - return -} - -func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { - return debugSql(s.parent, dialect...) -} - -func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(s.parent, db, destination) -} - -func (s *selectStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - return queryContext(context, s.parent, db, destination) -} - -func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(s.parent, db) -} - -func (s *selectStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - return execContext(context, s.parent, db) -} - -// SelectLock is interface for SELECT statement locks -type SelectLock interface { - Serializer - - NOWAIT() SelectLock - SKIP_LOCKED() SelectLock -} - -type selectLockImpl struct { - lockStrength string - noWait, skipLocked bool -} - -func NewSelectLock(name string) func() SelectLock { - return func() SelectLock { - return newSelectLock(name) - } -} - -func newSelectLock(lockStrength string) SelectLock { - return &selectLockImpl{lockStrength: lockStrength} -} - -func (s *selectLockImpl) NOWAIT() SelectLock { - s.noWait = true - return s -} - -func (s *selectLockImpl) SKIP_LOCKED() SelectLock { - s.skipLocked = true - return s -} - -func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - out.WriteString(s.lockStrength) - - if s.noWait { - out.WriteString("NOWAIT") - } - - if s.skipLocked { - out.WriteString("SKIP LOCKED") - } - - return nil -} diff --git a/internal/jet/select_statement_test.go b/internal/jet/select_statement_test.go deleted file mode 100644 index 4c79c08..0000000 --- a/internal/jet/select_statement_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package jet - -import "testing" - -func TestInvalidSelect(t *testing.T) { - assertStatementErr(t, SELECT(nil), "jet: Projection is nil") -} - -func TestSelectColumnList(t *testing.T) { - columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt) - - assertStatement(t, SELECT(columnList).FROM(table2), ` -SELECT table2.col_int AS "table2.col_int", - table2.col_float AS "table2.col_float", - table3.col_int AS "table3.col_int" -FROM db.table2; -`) -} - -func TestSelectLiterals(t *testing.T) { - assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` -SELECT $1, - $2, - $3 -FROM db.table1; -`, int64(1), 2.2, false) -} - -func TestSelectDistinct(t *testing.T) { - assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` -SELECT DISTINCT table1.col_bool AS "table1.col_bool" -FROM db.table1; -`) -} - -func TestSelectFrom(t *testing.T) { - assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` -SELECT table1.col_int AS "table1.col_int", - table2.col_float AS "table2.col_float" -FROM db.table1; -`) - assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` -SELECT table1.col_int AS "table1.col_int", - table2.col_float AS "table2.col_float" -FROM db.table1 - INNER JOIN db.table2 ON (table1.col_int = table2.col_int); -`) - assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` -SELECT table1.col_int AS "table1.col_int", - table2.col_float AS "table2.col_float" -FROM db.table1 - INNER JOIN db.table2 ON (table1.col_int = table2.col_int); -`) -} - -func TestSelectWhere(t *testing.T) { - assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` -SELECT table1.col_int AS "table1.col_int" -FROM db.table1 -WHERE $1; -`, true) - assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` -SELECT table1.col_int AS "table1.col_int" -FROM db.table1 -WHERE table1.col_int >= $1; -`, int64(10)) -} - -func TestSelectGroupBy(t *testing.T) { - assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` -SELECT table2.col_int AS "table2.col_int" -FROM db.table2 -GROUP BY table2.col_float; -`) -} - -func TestSelectHaving(t *testing.T) { - assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` -SELECT table3.col_int AS "table3.col_int" -FROM db.table3 -HAVING table1.col_bool = $1; -`, true) -} - -func TestSelectOrderBy(t *testing.T) { - assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` -SELECT table2.col_float AS "table2.col_float" -FROM db.table2 -ORDER BY table2.col_int DESC; -`) - assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` -SELECT table2.col_float AS "table2.col_float" -FROM db.table2 -ORDER BY table2.col_int DESC, table2.col_int ASC; -`) -} - -func TestSelectLimitOffset(t *testing.T) { - assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` -SELECT table2.col_int AS "table2.col_int" -FROM db.table2 -LIMIT $1; -`, int64(10)) - assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` -SELECT table2.col_int AS "table2.col_int" -FROM db.table2 -LIMIT $1 -OFFSET $2; -`, int64(10), int64(2)) -} - -func TestSelectSets(t *testing.T) { - select1 := SELECT(table1ColBool).FROM(table1) - select2 := SELECT(table2ColBool).FROM(table2) - - assertStatement(t, select1.UNION(select2), ` -( - ( - SELECT table1.col_bool AS "table1.col_bool" - FROM db.table1 - ) - UNION - ( - SELECT table2.col_bool AS "table2.col_bool" - FROM db.table2 - ) -); -`) - assertStatement(t, select1.UNION_ALL(select2), ` -( - ( - SELECT table1.col_bool AS "table1.col_bool" - FROM db.table1 - ) - UNION ALL - ( - SELECT table2.col_bool AS "table2.col_bool" - FROM db.table2 - ) -); -`) - - assertStatement(t, select1.INTERSECT(select2), ` -( - ( - SELECT table1.col_bool AS "table1.col_bool" - FROM db.table1 - ) - INTERSECT - ( - SELECT table2.col_bool AS "table2.col_bool" - FROM db.table2 - ) -); -`) - - assertStatement(t, select1.INTERSECT_ALL(select2), ` -( - ( - SELECT table1.col_bool AS "table1.col_bool" - FROM db.table1 - ) - INTERSECT ALL - ( - SELECT table2.col_bool AS "table2.col_bool" - FROM db.table2 - ) -); -`) - assertStatement(t, select1.EXCEPT(select2), ` -( - ( - SELECT table1.col_bool AS "table1.col_bool" - FROM db.table1 - ) - EXCEPT - ( - SELECT table2.col_bool AS "table2.col_bool" - FROM db.table2 - ) -); -`) - - assertStatement(t, select1.EXCEPT_ALL(select2), ` -( - ( - SELECT table1.col_bool AS "table1.col_bool" - FROM db.table1 - ) - EXCEPT ALL - ( - SELECT table2.col_bool AS "table2.col_bool" - FROM db.table2 - ) -); -`) - -} diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index f4db54e..aa6509e 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -1,70 +1,58 @@ package jet -//// SelectTable is interface for SELECT sub-queries -//type SelectTable interface { -// ReadableTable -// -// Alias() string -// -// AllColumns() ProjectionList -//} -// -//type selectTableImpl struct { -// readableTableInterfaceImpl -// selectStmt SelectStatement -// alias string -// -// projections []Projection -//} -// -//func newSelectTable(selectStmt SelectStatement, alias string) SelectTable { -// expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} -// -// expTable.readableTableInterfaceImpl.parent = expTable -// -// for _, projection := range selectStmt.projections() { -// newProjection := projection.fromImpl(expTable) -// -// expTable.projections = append(expTable.projections, newProjection) -// } -// -// return expTable -//} -// -//func (s *selectTableImpl) Alias() string { -// return s.alias -//} -// -//func (s *selectTableImpl) columns() []IColumn { -// return nil -//} -// -//func (s *selectTableImpl) accept(visitor visitor) { -// visitor.visit(s) -// s.selectStmt.accept(visitor) -//} -// -//func (s *selectTableImpl) dialect() Dialect { -// return detectDialect(s.selectStmt) -//} -// -//func (s *selectTableImpl) AllColumns() ProjectionList { -// return s.projections -//} -// -//func (s *selectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { -// if s == nil { -// return errors.New("jet: Expression table is nil. ") -// } -// -// err := s.selectStmt.serialize(statement, out) -// -// if err != nil { -// return err -// } -// -// out.WriteString("AS") -// out.writeIdentifier(s.alias) -// -// return nil -//} +import "errors" + +// SelectTable is interface for SELECT sub-queries +type SelectTable interface { + Alias() string + AllColumns() ProjectionList +} + +type SelectTableImpl2 struct { + selectStmt StatementWithProjections + alias string + + projections []Projection +} + +func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 { + selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias} + + for _, projection := range selectStmt.projections() { + newProjection := projection.fromImpl(&selectTable) + + selectTable.projections = append(selectTable.projections, newProjection) + } + + return selectTable +} + +func (s *SelectTableImpl2) Alias() string { + return s.alias +} + +func (s *SelectTableImpl2) accept(visitor visitor) { + visitor.visit(s) + s.selectStmt.accept(visitor) +} + +func (s *SelectTableImpl2) AllColumns() ProjectionList { + return s.projections +} + +func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { + if s == nil { + return errors.New("jet: Expression table is nil. ") + } + + err := s.selectStmt.serialize(statement, out) + + if err != nil { + return err + } + + out.WriteString("AS") + out.writeIdentifier(s.alias) + + return nil +} diff --git a/internal/jet/set_statement.go b/internal/jet/set_statement.go deleted file mode 100644 index 77dd783..0000000 --- a/internal/jet/set_statement.go +++ /dev/null @@ -1,197 +0,0 @@ -package jet - -import ( - "errors" -) - -// UNION effectively appends the result of sub-queries(select statements) into single query. -// It eliminates duplicate rows from its result. -func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(Union, false, toSelectList(lhs, rhs, selects...)) -} - -// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. -// It does not eliminates duplicate rows from its result. -func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(Union, true, toSelectList(lhs, rhs, selects...)) -} - -// INTERSECT returns all rows that are in query results. -// It eliminates duplicate rows from its result. -func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(Intersect, false, toSelectList(lhs, rhs, selects...)) -} - -// INTERSECT_ALL returns all rows that are in query results. -// It does not eliminates duplicate rows from its result. -func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(Intersect, true, toSelectList(lhs, rhs, selects...)) -} - -// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs. -// It eliminates duplicate rows from its result. -func EXCEPT(lhs, rhs SelectStatement) SelectStatement { - return newSetStatementImpl(Except, false, toSelectList(lhs, rhs)) -} - -// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs. -// It does not eliminates duplicate rows from its result. -func EXCEPT_ALL(lhs, rhs SelectStatement) SelectStatement { - return newSetStatementImpl(Except, true, toSelectList(lhs, rhs)) -} - -func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement { - return append([]SelectStatement{lhs, rhs}, selects...) -} - -const ( - Union = "UNION" - Intersect = "INTERSECT" - Except = "EXCEPT" -) - -// Similar to selectStatementImpl, but less complete -type setStatementImpl struct { - selectStatementImpl - - operator string - all bool - selects []SelectStatement -} - -func newSetStatementImpl(operator string, all bool, selects []SelectStatement) SelectStatement { - setStatement := &setStatementImpl{ - operator: operator, - all: all, - selects: selects, - } - - setStatement.selectStatementImpl.ExpressionInterfaceImpl.Parent = setStatement - setStatement.selectStatementImpl.parent = setStatement - setStatement.limit = -1 - setStatement.offset = -1 - - return setStatement -} - -func (s *setStatementImpl) accept(visitor visitor) { - visitor.visit(s) - - for _, selects := range s.selects { - selects.accept(visitor) - } -} - -func (s *setStatementImpl) projections() []Projection { - if len(s.selects) > 0 { - return s.selects[0].projections() - } - return []Projection{} -} - -func (s *setStatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if s == nil { - return errors.New("jet: Set expression is nil. ") - } - - wrap := s.orderBy != nil || s.limit >= 0 || s.offset >= 0 - - if wrap { - out.WriteString("(") - out.increaseIdent() - } - - err := s.serializeImpl(out) - - if err != nil { - return err - } - - if wrap { - out.decreaseIdent() - out.NewLine() - out.WriteString(")") - } - - return nil -} - -func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error { - if s == nil { - return errors.New("jet: Set expression is nil. ") - } - - if len(s.selects) < 2 { - return errors.New("jet: UNION Statement must have at least two SELECT statements") - } - - if setOverride := out.Dialect.SerializeOverride(s.operator); setOverride != nil { - return setOverride()(SelectStatementType, out) - } - - out.NewLine() - out.WriteString("(") - out.increaseIdent() - - for i, selectStmt := range s.selects { - out.NewLine() - if i > 0 { - out.WriteString(s.operator) - - if s.all { - out.WriteString("ALL") - } - out.NewLine() - } - - if selectStmt == nil { - return errors.New("jet: select statement is nil") - } - - err := selectStmt.serialize(SetStatementType, out) - - if err != nil { - return err - } - } - - out.decreaseIdent() - out.NewLine() - out.WriteString(")") - - if s.orderBy != nil { - err := out.writeOrderBy(SetStatementType, s.orderBy) - if err != nil { - return err - } - } - - if s.limit >= 0 { - out.NewLine() - out.WriteString("LIMIT") - out.insertParametrizedArgument(s.limit) - } - - if s.offset >= 0 { - out.NewLine() - out.WriteString("OFFSET") - out.insertParametrizedArgument(s.offset) - } - - return nil -} - -func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { - queryData := &SqlBuilder{ - Dialect: detectDialect(s, dialect...), - } - - err = s.serializeImpl(queryData) - - if err != nil { - return - } - - query, args = queryData.finalize() - return -} diff --git a/internal/jet/set_statement_test.go b/internal/jet/set_statement_test.go deleted file mode 100644 index 159fb0b..0000000 --- a/internal/jet/set_statement_test.go +++ /dev/null @@ -1,301 +0,0 @@ -package jet - -import ( - "gotest.tools/assert" - "testing" -) - -func TestUnionTwoSelect(t *testing.T) { - var expectedSQL = ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -); -` - unionStmt1 := table1. - SELECT(table1Col1). - UNION( - table2.SELECT(table2Col3), - ) - - unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3)) - - assertStatement(t, unionStmt1, expectedSQL) - assertStatement(t, unionStmt2, expectedSQL) -} - -func TestUnionNilSelect(t *testing.T) { - unionStmt := table1. - SELECT(table1Col1). - UNION(nil) - - assertStatementErr(t, unionStmt, "jet: select statement is nil") -} - -func TestUnionThreeSelect1(t *testing.T) { - - unionStmt1 := table1.SELECT(table1Col1). - UNION( - table2.SELECT(table2Col3), - ). - UNION( - table3.SELECT(table3Col1), - ) - - var expectedSQL = ` -( - - ( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) - ) - UNION - ( - SELECT table3.col1 AS "table3.col1" - FROM db.table3 - ) -); -` - - assertStatement(t, unionStmt1, expectedSQL) -} - -func TestUnionThreeSelect2(t *testing.T) { - - unionStmt2 := UNION( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - table3.SELECT(table3Col1), - ) - - var expectedSQL = ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) - UNION - ( - SELECT table3.col1 AS "table3.col1" - FROM db.table3 - ) -); -` - - assertStatement(t, unionStmt2, expectedSQL) -} - -func TestUnionWithOrderBy(t *testing.T) { - unionStmt := UNION( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ). - ORDER_BY(table1Col1.ASC()) - - assertStatement(t, unionStmt, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -) -ORDER BY "table1.col1" ASC; -`) -} - -func TestUnionWithLimitAndOffset(t *testing.T) { - query, args, err := UNION( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ). - LIMIT(10). - OFFSET(11).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -) -LIMIT $1 -OFFSET $2; -`) - assert.Equal(t, len(args), 2) -} - -func TestUnionInUnion(t *testing.T) { - expectedSQL := ` -( - ( - SELECT table2.col3 AS "table2.col3", - table2.col3 AS "table2.col3" - FROM db.table2 - ) - UNION - - ( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION ALL - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) - ) -); -` - query := UNION( - SELECT(table2Col3, table2Col3).FROM(table2), - UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)), - ) - - assertStatement(t, query, expectedSQL) -} - -func TestUnionALL(t *testing.T) { - query, args, err := UNION_ALL( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - UNION ALL - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -); -`) - assert.Equal(t, len(args), 0) -} - -func TestINTERSECT(t *testing.T) { - query, args, err := INTERSECT( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - INTERSECT - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -); -`) - assert.Equal(t, len(args), 0) -} - -func TestINTERSECT_ALL(t *testing.T) { - query, args, err := INTERSECT_ALL( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - INTERSECT ALL - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -); -`) - assert.Equal(t, len(args), 0) -} - -func TestEXCEPT(t *testing.T) { - query, args, err := EXCEPT( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - EXCEPT - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -); -`) - assert.Equal(t, len(args), 0) -} - -func TestEXCEPT_ALL(t *testing.T) { - query, args, err := EXCEPT_ALL( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` -( - ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 - ) - EXCEPT ALL - ( - SELECT table2.col3 AS "table2.col3" - FROM db.table2 - ) -); -`) - assert.Equal(t, len(args), 0) -} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index dff7cea..ec0f449 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -24,7 +24,7 @@ func (s *SqlBuilder) DebugSQL() string { const defaultIdent = 5 -func (q *SqlBuilder) increaseIdent(ident ...int) { +func (q *SqlBuilder) IncreaseIdent(ident ...int) { if len(ident) > 0 { q.ident += ident[0] } else { @@ -32,7 +32,7 @@ func (q *SqlBuilder) increaseIdent(ident ...int) { } } -func (q *SqlBuilder) decreaseIdent(ident ...int) { +func (q *SqlBuilder) DecreaseIdent(ident ...int) { toDecrease := defaultIdent if len(ident) > 0 { @@ -46,10 +46,10 @@ func (q *SqlBuilder) decreaseIdent(ident ...int) { q.ident -= toDecrease } -func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error { - q.increaseIdent() +func (q *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error { + q.IncreaseIdent() err := SerializeProjectionList(statement, projections, q) - q.decreaseIdent() + q.DecreaseIdent() return err } @@ -57,9 +57,9 @@ func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error q.NewLine() q.WriteString("FROM") - q.increaseIdent() + q.IncreaseIdent() err := table.serialize(statement, q) - q.decreaseIdent() + q.DecreaseIdent() return err } @@ -68,9 +68,9 @@ func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error q.NewLine() q.WriteString("WHERE") - q.increaseIdent() + q.IncreaseIdent() err := where.serialize(statement, q, noWrap) - q.decreaseIdent() + q.DecreaseIdent() return err } @@ -79,9 +79,9 @@ func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClau q.NewLine() q.WriteString("GROUP BY") - q.increaseIdent() + q.IncreaseIdent() err := serializeGroupByClauseList(statement, groupBy, q) - q.decreaseIdent() + q.DecreaseIdent() return err } @@ -90,9 +90,9 @@ func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClau q.NewLine() q.WriteString("ORDER BY") - q.increaseIdent() + q.IncreaseIdent() err := serializeOrderByClauseList(statement, orderBy, q) - q.decreaseIdent() + q.DecreaseIdent() return err } @@ -101,9 +101,9 @@ func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) err q.NewLine() q.WriteString("HAVING") - q.increaseIdent() + q.IncreaseIdent() err := having.serialize(statement, q, noWrap) - q.decreaseIdent() + q.DecreaseIdent() return err } @@ -115,9 +115,9 @@ func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Project q.NewLine() q.WriteString("RETURNING") - q.increaseIdent() + q.IncreaseIdent() - return q.writeProjections(statement, returning) + return q.WriteProjections(statement, returning) } func (q *SqlBuilder) NewLine() { diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 4057a7d..4a27b44 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -204,7 +204,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti if !contains(options, noWrap) { out.WriteString("(") - out.increaseIdent() + out.IncreaseIdent() } for _, clause := range s.Clauses { @@ -216,7 +216,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti } if !contains(options, noWrap) { - out.decreaseIdent() + out.DecreaseIdent() out.NewLine() out.WriteString(")") } diff --git a/internal/jet/table.go b/internal/jet/table.go index c0b69da..59174d2 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -11,154 +11,36 @@ type SerializerTable interface { } type TableInterface interface { - Columns() []IColumn -} - -type TableBase interface { - dialect() Dialect - columns() []IColumn -} - -type readableTable interface { - // Generates a select query on the current tableName. - SELECT(projection Projection, projections ...Projection) SelectStatement - - // Creates a inner join tableName Expression using onCondition. - INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable - - // Creates a left join tableName Expression using onCondition. - LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable - - // Creates a right join tableName Expression using onCondition. - RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable - - // Creates a full join tableName Expression using onCondition. - FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable - - // Creates a cross join tableName Expression using onCondition. - CROSS_JOIN(table ReadableTable) ReadableTable -} - -type writableTable interface { - INSERT(columns ...IColumn) InsertStatement - UPDATE(column IColumn, columns ...IColumn) UpdateStatement - DELETE() DeleteStatement - - LOCK() LockStatement -} - -// ReadableTable interface -type ReadableTable interface { - TableBase - readableTable - Serializer - acceptsVisitor -} - -// WritableTable interface -type WritableTable interface { - TableBase - writableTable - Serializer - acceptsVisitor -} - -// Table interface -type Table interface { - TableBase - readableTable - writableTable - Serializer - acceptsVisitor - + Columns() []Column SchemaName() string TableName() string AS(alias string) } -type readableTableInterfaceImpl struct { - parent ReadableTable -} - -// Generates a select query on the current tableName. -func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { - return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) -} - -// Creates a inner join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, InnerJoin, onCondition) -} - -// Creates a left join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, LeftJoin, onCondition) -} - -// Creates a right join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, RightJoin, onCondition) -} - -func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, FullJoin, onCondition) -} - -func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { - return newJoinTable(r.parent, table, CrossJoin, nil) -} - -type writableTableInterfaceImpl struct { - parent WritableTable -} - -func (w *writableTableInterfaceImpl) INSERT(columns ...IColumn) InsertStatement { - return newInsertStatement(w.parent, UnwidColumnList(columns)) -} - -func (w *writableTableInterfaceImpl) UPDATE(column IColumn, columns ...IColumn) UpdateStatement { - return newUpdateStatement(w.parent, UnwindColumns(column, columns...)) -} - -func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { - return newDeleteStatement(w.parent) -} - -func (w *writableTableInterfaceImpl) LOCK() LockStatement { - return LOCK(w.parent) -} - // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table { +func NewTable(schemaName, name string, columns ...ColumnExpression) TableImpl { - t := &tableImpl{ - Dialect: Dialect, + t := TableImpl{ schemaName: schemaName, name: name, columnList: columns, } + for _, c := range columns { c.SetTableName(name) } - t.readableTableInterfaceImpl.parent = t - t.writableTableInterfaceImpl.parent = t - return t } -type tableImpl struct { - readableTableInterfaceImpl - writableTableInterfaceImpl - - Dialect Dialect +type TableImpl struct { schemaName string name string alias string - columnList []Column + columnList []ColumnExpression } -func (t *tableImpl) AS(alias string) { +func (t *TableImpl) AS(alias string) { t.alias = alias for _, c := range t.columnList { @@ -166,16 +48,16 @@ func (t *tableImpl) AS(alias string) { } } -func (t *tableImpl) SchemaName() string { +func (t *TableImpl) SchemaName() string { return t.schemaName } -func (t *tableImpl) TableName() string { +func (t *TableImpl) TableName() string { return t.name } -func (t *tableImpl) columns() []IColumn { - ret := []IColumn{} +func (t *TableImpl) Columns() []Column { + ret := []Column{} for _, col := range t.columnList { ret = append(ret, col) @@ -184,15 +66,11 @@ func (t *tableImpl) columns() []IColumn { return ret } -func (t *tableImpl) dialect() Dialect { - return t.Dialect -} - -func (t *tableImpl) accept(visitor visitor) { +func (t *TableImpl) accept(visitor visitor) { visitor.visit(t) } -func (t *tableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { if t == nil { return errors.New("jet: tableImpl is nil. ") } @@ -220,55 +98,45 @@ const ( ) // Join expressions are pseudo readable tables. -type joinTable struct { - readableTableInterfaceImpl - - lhs ReadableTable - rhs ReadableTable +type JoinTableImpl struct { + lhs Serializer + rhs Serializer joinType JoinType onCondition BoolExpression } -func newJoinTable( - lhs ReadableTable, - rhs ReadableTable, - joinType JoinType, - onCondition BoolExpression) *joinTable { +func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl { - joinTable := &joinTable{ + joinTable := JoinTableImpl{ lhs: lhs, rhs: rhs, joinType: joinType, onCondition: onCondition, } - joinTable.readableTableInterfaceImpl.parent = joinTable - return joinTable } -func (t *joinTable) SchemaName() string { +func (t *JoinTableImpl) SchemaName() string { return "" } -func (t *joinTable) TableName() string { +func (t *JoinTableImpl) TableName() string { return "" } -func (t *joinTable) columns() []IColumn { - return append(t.lhs.columns(), t.rhs.columns()...) +func (t *JoinTableImpl) Columns() []Column { + //return append(t.lhs.columns(), t.rhs.columns()...) + panic("Unimplemented") } -func (t *joinTable) accept(visitor visitor) { - t.lhs.accept(visitor) - t.rhs.accept(visitor) +func (t *JoinTableImpl) accept(visitor visitor) { + //t.lhs.accept(visitor) + //t.rhs.accept(visitor) + //TODO: remove } -func (t *joinTable) dialect() Dialect { - return detectDialect(t) -} - -func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) { +func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) { if t == nil { return errors.New("jet: Join table is nil. ") } @@ -318,8 +186,8 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options return nil } -func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn { - columnList := []IColumn{} +func UnwindColumns(column1 Column, columns ...Column) []Column { + columnList := []Column{} if val, ok := column1.(IColumnList); ok { for _, col := range val.Columns() { @@ -334,8 +202,8 @@ func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn { return columnList } -func UnwidColumnList(columns []IColumn) []IColumn { - ret := []IColumn{} +func UnwidColumnList(columns []Column) []Column { + ret := []Column{} for _, col := range columns { if columnList, ok := col.(IColumnList); ok { diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 97f4a80..8d71556 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -2,9 +2,18 @@ package jet import ( "gotest.tools/assert" + "strconv" "testing" ) +var DefaultDialect = NewDialect(DialectParams{ // just for tests + AliasQuoteChar: '"', + IdentifierQuoteChar: '"', + ArgumentPlaceholder: func(ord int) string { + return "$" + strconv.Itoa(ord) + }, +}) + var table1Col1 = IntegerColumn("col1") var table1ColInt = IntegerColumn("col_int") var table1ColFloat = FloatColumn("col_float") @@ -16,21 +25,7 @@ var table1ColTimestampz = TimestampzColumn("col_timestampz") var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") -var table1 = NewTable( - ANSII, - "db", - "table1", - table1Col1, - table1ColInt, - table1ColFloat, - table1Col3, - table1ColTime, - table1ColTimez, - table1ColBool, - table1ColDate, - table1ColTimestamp, - table1ColTimestampz, -) +var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz) var table2Col3 = IntegerColumn("col3") var table2Col4 = IntegerColumn("col4") @@ -44,46 +39,27 @@ var table2ColTimestamp = TimestampColumn("col_timestamp") var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") -var table2 = NewTable( - ANSII, - "db", - "table2", - table2Col3, - table2Col4, - table2ColInt, - table2ColFloat, - table2ColStr, - table2ColBool, - table2ColTime, - table2ColTimez, - table2ColDate, - table2ColTimestamp, - table2ColTimestampz, -) +var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz) var table3Col1 = IntegerColumn("col1") var table3ColInt = IntegerColumn("col_int") var table3StrCol = StringColumn("col2") -var table3 = NewTable( - ANSII, - "db", - "table3", - table3Col1, - table3ColInt, - table3StrCol) +var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol) func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { - out := SqlBuilder{Dialect: ANSII} + out := SqlBuilder{Dialect: DefaultDialect} err := clause.serialize(SelectStatementType, &out) assert.NilError(t, err) + //fmt.Println(out.Buff.String()) + assert.DeepEqual(t, out.Buff.String(), query) assert.DeepEqual(t, out.Args, args) } func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { - out := SqlBuilder{Dialect: ANSII} + out := SqlBuilder{Dialect: DefaultDialect} err := clause.serialize(SelectStatementType, &out) //fmt.Println(out.buff.String()) @@ -92,7 +68,7 @@ func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) } func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { - out := SqlBuilder{Dialect: ANSII} + out := SqlBuilder{Dialect: DefaultDialect} err := projection.serializeForProjection(SelectStatementType, &out) assert.NilError(t, err) @@ -110,7 +86,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect } func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { - _, _, err := stmt.Sql(ANSII) + _, _, err := stmt.Sql(DefaultDialect) assert.Assert(t, err != nil) assert.Error(t, err, errorStr) diff --git a/internal/jet/timestampz_expression_test.go b/internal/jet/timestampz_expression_test.go index d67b2b6..8ec30d3 100644 --- a/internal/jet/timestampz_expression_test.go +++ b/internal/jet/timestampz_expression_test.go @@ -1,5 +1,3 @@ -// +build todo - package jet import "testing" @@ -9,46 +7,46 @@ var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2) func TestTimestampzExpressionEQ(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz), - "(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + "(table1.col_timestampz = $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionNOT_EQ(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionLT(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionLT_EQ(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionGT(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExpressionGT_EQ(t *testing.T) { assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)") - assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1)", "2000-01-31 10:20:00.000 +002") } func TestTimestampzExp(t *testing.T) { assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), - "(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") + "(table1.col_float < $1)", "2000-01-31 10:20:00.000 +002") } diff --git a/internal/jet/timez_expression_test.go b/internal/jet/timez_expression_test.go index e60fc16..23aef52 100644 --- a/internal/jet/timez_expression_test.go +++ b/internal/jet/timez_expression_test.go @@ -1,5 +1,3 @@ -// +build TODO - package jet import "testing" @@ -8,46 +6,46 @@ var timezVar = Timez(10, 20, 0, 0, 4) func TestTimezExpressionEQ(t *testing.T) { assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1)", "10:20:00.000 +04") } func TestTimezExpressionNOT_EQ(t *testing.T) { assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1)", "10:20:00.000 +04") } func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) { assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1)", "10:20:00.000 +04") } func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1)", "10:20:00.000 +04") } func TestTimezExpressionLT(t *testing.T) { assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1)", "10:20:00.000 +04") } func TestTimezExpressionLT_EQ(t *testing.T) { assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1)", "10:20:00.000 +04") } func TestTimezExpressionGT(t *testing.T) { assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1)", "10:20:00.000 +04") } func TestTimezExpressionGT_EQ(t *testing.T) { assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)") - assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04") + assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1)", "10:20:00.000 +04") } func TestTimezExp(t *testing.T) { assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)), - "(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04")) + "(table1.col_float < $1)", string("01:01:01.001 +04")) } diff --git a/internal/jet/update_statement.go b/internal/jet/update_statement.go deleted file mode 100644 index 59dc9ed..0000000 --- a/internal/jet/update_statement.go +++ /dev/null @@ -1,126 +0,0 @@ -package jet - -import ( - "context" - "database/sql" - "errors" - "github.com/go-jet/jet/execution" - "github.com/go-jet/jet/internal/utils" -) - -// UpdateStatement is interface of SQL UPDATE statement -type UpdateStatement interface { - Statement - - SET(value interface{}, values ...interface{}) UpdateStatement - MODEL(data interface{}) UpdateStatement - - WHERE(expression BoolExpression) UpdateStatement - RETURNING(projections ...Projection) UpdateStatement -} - -func newUpdateStatement(table WritableTable, columns []IColumn) UpdateStatement { - return &updateStatementImpl{ - table: table, - columns: columns, - values: make([]Serializer, 0, len(columns)), - } -} - -type updateStatementImpl struct { - table WritableTable - columns []IColumn - values []Serializer - where BoolExpression - returning []Projection -} - -func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { - u.values = UnwindRowFromValues(value, values) - - return u -} - -func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { - u.values = UnwindRowFromModel(u.columns, data) - - return u -} - -func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { - u.where = expression - return u -} - -func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement { - u.returning = projections - return u -} - -func (u *updateStatementImpl) accept(visitor visitor) { - visitor.visit(u) - u.table.accept(visitor) -} - -func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { - out := &SqlBuilder{ - Dialect: detectDialect(u, dialect...), - } - - out.NewLine() - out.WriteString("UPDATE") - - if utils.IsNil(u.table) { - return "", nil, errors.New("jet: table to update is nil") - } - - if err = u.table.serialize(UpdateStatementType, out); err != nil { - return - } - - if len(u.columns) == 0 { - return "", nil, errors.New("jet: no columns selected") - } - - if len(u.values) == 0 { - return "", nil, errors.New("jet: no values to updated") - } - - out.NewLine() - out.WriteString("SET") - - if u.where == nil { - return "", nil, errors.New("jet: WHERE clause not set") - } - - if err = out.writeWhere(UpdateStatementType, u.where); err != nil { - return - } - - if err = out.WriteReturning(UpdateStatementType, u.returning); err != nil { - return - } - - query, args = out.finalize() - return -} - -func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { - return debugSql(u, dialect...) -} - -func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(u, db, destination) -} - -func (u *updateStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - return queryContext(context, u, db, destination) -} - -func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(u, db) -} - -func (u *updateStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - return execContext(context, u, db) -} diff --git a/internal/jet/update_statement_test.go b/internal/jet/update_statement_test.go deleted file mode 100644 index 37ddde7..0000000 --- a/internal/jet/update_statement_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package jet - -import ( - "testing" -) - -func TestUpdateWithOneValue(t *testing.T) { - expectedSQL := ` -UPDATE db.table1 -SET col_int = $1 -WHERE table1.col_int >= $2; -` - stmt := table1.UPDATE(table1ColInt). - SET(1). - WHERE(table1ColInt.GT_EQ(Int(33))) - - assertStatement(t, stmt, expectedSQL, 1, int64(33)) -} - -func TestUpdateWithValues(t *testing.T) { - expectedSQL := ` -UPDATE db.table1 -SET col_int = $1, col_float = $2 -WHERE table1.col_int >= $3; -` - stmt := table1.UPDATE(table1ColInt, table1ColFloat). - SET(1, 22.2). - WHERE(table1ColInt.GT_EQ(Int(33))) - - assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33)) -} - -func TestUpdateOneColumnWithSelect(t *testing.T) { - expectedSQL := ` -UPDATE db.table1 -SET col_float = ( - SELECT table1.col_float AS "table1.col_float" - FROM db.table1 -) -WHERE table1.col1 = $1 -RETURNING table1.col1 AS "table1.col1"; -` - stmt := table1. - UPDATE(table1ColFloat). - SET( - table1.SELECT(table1ColFloat), - ). - WHERE(table1Col1.EQ(Int(2))). - RETURNING(table1Col1) - - assertStatement(t, stmt, expectedSQL, int64(2)) -} - -func TestInvalidInputs(t *testing.T) { - assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") - assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") -} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index e1a7924..9e8a69f 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -98,7 +98,7 @@ func SerializeProjectionList(statement StatementType, projections []Projection, return nil } -func SerializeColumnNames(columns []IColumn, out *SqlBuilder) error { +func SerializeColumnNames(columns []Column, out *SqlBuilder) error { for i, col := range columns { if i > 0 { out.WriteString(", ") @@ -114,7 +114,7 @@ func SerializeColumnNames(columns []IColumn, out *SqlBuilder) error { return nil } -func ColumnListToProjectionList(columns []Column) []Projection { +func ColumnListToProjectionList(columns []ColumnExpression) []Projection { var ret []Projection for _, column := range columns { @@ -132,7 +132,7 @@ func valueToClause(value interface{}) Serializer { return literal(value) } -func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer { +func UnwindRowFromModel(columns []Column, data interface{}) []Serializer { structValue := reflect.Indirect(reflect.ValueOf(data)) row := []Serializer{} @@ -163,7 +163,7 @@ func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer { return row } -func UnwindRowsFromModels(columns []IColumn, data interface{}) [][]Serializer { +func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer { sliceValue := reflect.Indirect(reflect.ValueOf(data)) mustBe(sliceValue, reflect.Slice) diff --git a/internal/jet/visitor.go b/internal/jet/visitor.go index ff6172b..4586244 100644 --- a/internal/jet/visitor.go +++ b/internal/jet/visitor.go @@ -45,10 +45,10 @@ func (f *DialectFinder) mustGetDialect() Dialect { func (f *DialectFinder) visit(element acceptsVisitor) { - if table, ok := element.(TableBase); ok { - dialect := table.dialect() - f.dialects[dialect.Name()] = dialect - } + //if table, ok := element.(TableBase); ok { + // dialect := table.dialect() + // f.dialects[dialect.Name()] = dialect + //} } func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect { diff --git a/mysql/cast.go b/mysql/cast.go index 0707efa..c53005e 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -14,13 +14,13 @@ type cast interface { } type castImpl struct { - jet.CastImpl + jet.Cast } func CAST(expr jet.Expression) cast { castImpl := &castImpl{} - castImpl.CastImpl = jet.NewCastImpl(expr) + castImpl.Cast = jet.NewCastImpl(expr) return castImpl } diff --git a/mysql/columns.go b/mysql/columns.go index 819cee4..7dd6c0e 100644 --- a/mysql/columns.go +++ b/mysql/columns.go @@ -2,7 +2,7 @@ package mysql import "github.com/go-jet/jet/internal/jet" -type Column jet.Column +type Column jet.ColumnExpression type IColumnList jet.IColumnList diff --git a/mysql/dialect_test.go b/mysql/dialect_test.go index e334219..ed202d8 100644 --- a/mysql/dialect_test.go +++ b/mysql/dialect_test.go @@ -33,3 +33,16 @@ func TestIntExpressionBIT_XOR(t *testing.T) { assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) } + +func TestExists(t *testing.T) { + assertClauseSerialize(t, EXISTS( + table2. + SELECT(Int(1)). + WHERE(table1Col1.EQ(table2Col3)), + ), + `(EXISTS ( + SELECT ? + FROM db.table2 + WHERE table1.col1 = table2.col3 +))`, int64(1)) +} diff --git a/mysql/insert_statement.go b/mysql/insert_statement.go index 21c239f..3232b52 100644 --- a/mysql/insert_statement.go +++ b/mysql/insert_statement.go @@ -16,9 +16,9 @@ type InsertStatement interface { QUERY(selectStatement SelectStatement) InsertStatement } -func newInsertStatement(table Table, columns []jet.IColumn) InsertStatement { +func newInsertStatement(table Table, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} - newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newInsert, + newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, &newInsert.Insert, &newInsert.Values, &newInsert.Select) newInsert.Insert.Table = table @@ -41,12 +41,12 @@ func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) I } func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.getColumns(), data)) + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data)) return i } func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { - i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.getColumns(), data)...) + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...) return i } @@ -54,11 +54,3 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState i.Select.Query = selectStatement return i } - -func (i *insertStatementImpl) getColumns() []jet.IColumn { - if len(i.Insert.Columns) > 0 { - return i.Insert.Columns - } - - return i.Insert.Table.Columns() -} diff --git a/mysql/table.go b/mysql/table.go index c463882..cac5bd7 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -12,12 +12,12 @@ type Table interface { jet.SerializerTable readableTable - INSERT(columns ...jet.IColumn) InsertStatement - UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement + INSERT(columns ...jet.Column) InsertStatement + UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement DELETE() DeleteStatement //LOCK() LockStatement - AS(alias string) + //As(alias string) } type readableTable interface { @@ -41,8 +41,8 @@ type readableTable interface { } type ReadableTable interface { - jet.SerializerTable readableTable + jet.Serializer } type readableTableInterfaceImpl struct { @@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) Table { return newJoinTable(r.parent, table, jet.CrossJoin, nil) } -func NewTable(schemaName, name string, columns ...jet.Column) Table { +func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), + TableImpl: jet.NewTable(schemaName, name, columns...), } t.readableTableInterfaceImpl.parent = t @@ -89,16 +89,16 @@ func NewTable(schemaName, name string, columns ...jet.Column) Table { } type tableImpl struct { - jet.TableImpl2 + jet.TableImpl readableTableInterfaceImpl parent Table } -func (w *tableImpl) INSERT(columns ...jet.IColumn) InsertStatement { +func (w *tableImpl) INSERT(columns ...jet.Column) InsertStatement { return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) } -func (w *tableImpl) UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement { +func (w *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) } diff --git a/mysql/table_test.go b/mysql/table_test.go new file mode 100644 index 0000000..da45f36 --- /dev/null +++ b/mysql/table_test.go @@ -0,0 +1,101 @@ +package mysql + +import ( + "testing" +) + +func TestJoinNilInputs(t *testing.T) { + assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + "jet: right hand side of join operation is nil table") + assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), + "jet: join condition is nil") +} + +func TestINNER_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertClauseSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). + INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = table2.col_int) +INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertClauseSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(Int(1))). + INNER_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = ?) +INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestLEFT_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertClauseSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) +LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertClauseSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). + LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = ?) +LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestRIGHT_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertClauseSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) +RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertClauseSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). + RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = ?) +RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestFULL_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertClauseSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). + FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = table2.col_int) +FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertClauseSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(Int(1))). + FULL_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = ?) +FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestCROSS_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + CROSS_JOIN(table2), + `db.table1 +CROSS JOIN db.table2`) + assertClauseSerialize(t, table1. + CROSS_JOIN(table2). + CROSS_JOIN(table3), + `db.table1 +CROSS JOIN db.table2 +CROSS JOIN db.table3`) +} diff --git a/mysql/update_statement.go b/mysql/update_statement.go index cf82972..35db0e8 100644 --- a/mysql/update_statement.go +++ b/mysql/update_statement.go @@ -20,7 +20,7 @@ type updateStatementImpl struct { Where jet.ClauseWhere } -func newUpdateStatement(table Table, columns []jet.IColumn) UpdateStatement { +func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { update := &updateStatementImpl{} update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, &update.Set, &update.Where) diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 0a91e62..6af9545 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -64,6 +64,8 @@ func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, ar assert.NilError(t, err) + //fmt.Println(out.Buff.String()) + assert.DeepEqual(t, out.Buff.String(), query) assert.DeepEqual(t, out.Args, args) } diff --git a/postgres/cast.go b/postgres/cast.go index 1469267..747dd97 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -37,13 +37,13 @@ type cast interface { } type castImpl struct { - jet.CastImpl + jet.Cast } func CAST(expr Expression) cast { castImpl := &castImpl{} - castImpl.CastImpl = jet.NewCastImpl(expr) + castImpl.Cast = jet.NewCastImpl(expr) return castImpl } diff --git a/postgres/clauses.go b/postgres/clauses.go new file mode 100644 index 0000000..fa3d01b --- /dev/null +++ b/postgres/clauses.go @@ -0,0 +1,21 @@ +package postgres + +import ( + "github.com/go-jet/jet/internal/jet" +) + +type ClauseReturning struct { + Projections []jet.Projection +} + +func (r *ClauseReturning) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) error { + if len(r.Projections) == 0 { + return nil + } + + out.NewLine() + out.WriteString("RETURNING") + out.IncreaseIdent() + + return out.WriteProjections(statementType, r.Projections) +} diff --git a/postgres/columns.go b/postgres/columns.go index 2ccc252..fb6722a 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -2,7 +2,7 @@ package postgres import "github.com/go-jet/jet/internal/jet" -type Column jet.Column +type Column jet.ColumnExpression type IColumnList jet.IColumnList diff --git a/postgres/delete_statement.go b/postgres/delete_statement.go index 0fcfcfd..cf4caad 100644 --- a/postgres/delete_statement.go +++ b/postgres/delete_statement.go @@ -15,7 +15,7 @@ type deleteStatementImpl struct { Delete jet.ClauseStatementBegin Where jet.ClauseWhere - Returning jet.ClauseReturning + Returning ClauseReturning } func newDeleteStatement(table WritableTable) DeleteStatement { diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 16c694f..aecb97f 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -13,3 +13,48 @@ func TestString_REGEXP_LIKE_function(t *testing.T) { assertClauseSerialize(t, REGEXP_LIKE(table3StrCol, String("JOHN"), "c"), "table3.col2 ~ $1", "JOHN") assertClauseSerialize(t, REGEXP_LIKE(table3StrCol, String("JOHN"), "i"), "table3.col2 ~* $1", "JOHN") } + +func TestExists(t *testing.T) { + assertClauseSerialize(t, EXISTS( + table2. + SELECT(Int(1)). + WHERE(table1Col1.EQ(table2Col3)), + ), + `(EXISTS ( + SELECT $1 + FROM db.table2 + WHERE table1.col1 = table2.col3 +))`, int64(1)) +} + +func TestIN(t *testing.T) { + + assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), + `($1 IN (( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +)))`, float64(1.11)) + + assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), + `(ROW($1, table1.col1) IN (( + SELECT table2.col3 AS "table2.col3", + table3.col1 AS "table3.col1" + FROM db.table2 +)))`, int64(12)) +} + +func TestNOT_IN(t *testing.T) { + + assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), + `($1 NOT IN (( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +)))`, float64(1.11)) + + assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), + `(ROW($1, table1.col1) NOT IN (( + SELECT table2.col3 AS "table2.col3", + table3.col1 AS "table3.col1" + FROM db.table2 +)))`, int64(12)) +} diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go index c93bc9a..b570998 100644 --- a/postgres/insert_statement.go +++ b/postgres/insert_statement.go @@ -19,9 +19,9 @@ type InsertStatement interface { RETURNING(projections ...jet.Projection) InsertStatement } -func newInsertStatement(table WritableTable, columns []jet.IColumn) InsertStatement { +func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} - newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newInsert, + newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, &newInsert.Insert, &newInsert.Values, &newInsert.Select, &newInsert.Returning) newInsert.Insert.Table = table @@ -36,7 +36,7 @@ type insertStatementImpl struct { Insert jet.ClauseInsert Values jet.ClauseValues Select jet.ClauseQuery - Returning jet.ClauseReturning + Returning ClauseReturning } func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { @@ -45,12 +45,12 @@ func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) I } func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.getColumns(), data)) + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data)) return i } func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { - i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.getColumns(), data)...) + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...) return i } @@ -63,11 +63,3 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState i.Select.Query = selectStatement return i } - -func (i *insertStatementImpl) getColumns() []jet.IColumn { - if len(i.Insert.Columns) > 0 { - return i.Insert.Columns - } - - return i.Insert.Table.Columns() -} diff --git a/postgres/lock_statement.go b/postgres/lock_statement.go index 0e31f14..cf13cd8 100644 --- a/postgres/lock_statement.go +++ b/postgres/lock_statement.go @@ -25,7 +25,7 @@ type LockStatement interface { func LOCK(tables ...jet.SerializerTable) LockStatement { newLock := &lockStatementImpl{} - newLock.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newLock, + newLock.StatementImpl = jet.NewStatementImpl(Dialect, jet.LockStatementType, newLock, &newLock.StatementBegin, &newLock.In, &newLock.NoWait) newLock.StatementBegin.Name = "LOCK TABLE" diff --git a/postgres/table.go b/postgres/table.go index bce6fe3..2ceb556 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -23,8 +23,8 @@ type readableTable interface { } type writableTable interface { - INSERT(columns ...jet.IColumn) InsertStatement - UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement + INSERT(columns ...jet.Column) InsertStatement + UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement } @@ -47,12 +47,12 @@ type Table interface { //table readableTable writableTable - jet.Serializer + jet.SerializerTable //acceptsVisitor - SchemaName() string - TableName() string - AS(alias string) + //SchemaName() string + //TableName() string + //As(alias string) } type readableTableInterfaceImpl struct { @@ -91,11 +91,11 @@ type writableTableInterfaceImpl struct { parent WritableTable } -func (w *writableTableInterfaceImpl) INSERT(columns ...jet.IColumn) InsertStatement { +func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStatement { return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) } -func (w *writableTableInterfaceImpl) UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement { +func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) } @@ -111,13 +111,13 @@ type table2Impl struct { readableTableInterfaceImpl writableTableInterfaceImpl - jet.TableImpl2 + jet.TableImpl } -func NewTable(schemaName, name string, columns ...jet.Column) Table { +func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { t := &table2Impl{ - TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), + TableImpl: jet.NewTable(schemaName, name, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/internal/jet/table_test.go b/postgres/table_test.go similarity index 99% rename from internal/jet/table_test.go rename to postgres/table_test.go index 9a90da6..6573b02 100644 --- a/internal/jet/table_test.go +++ b/postgres/table_test.go @@ -1,4 +1,4 @@ -package jet +package postgres import ( "testing" diff --git a/postgres/update_statement.go b/postgres/update_statement.go index b06fc98..6f20326 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -22,10 +22,10 @@ type updateStatementImpl struct { Update jet.ClauseUpdate Set ClauseSet Where jet.ClauseWhere - Returning jet.ClauseReturning + Returning ClauseReturning } -func newUpdateStatement(table WritableTable, columns []jet.IColumn) UpdateStatement { +func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { update := &updateStatementImpl{} update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, &update.Set, &update.Where, &update.Returning) @@ -58,7 +58,7 @@ func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateSta } type ClauseSet struct { - Columns []jet.IColumn + Columns []jet.Column Values []jet.Serializer }