diff --git a/internal/jet/clause.go b/internal/jet/clause.go index b6c4e37..b92861d 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -25,7 +25,7 @@ func (s *ClauseSelect) projections() []Projection { } func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString("SELECT") if s.Distinct { @@ -74,7 +74,7 @@ func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) return nil } - out.newLine() + out.NewLine() out.WriteString("GROUP BY") out.increaseIdent() @@ -114,7 +114,7 @@ type ClauseLimit struct { func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) error { if l.Count >= 0 { - out.newLine() + out.NewLine() out.WriteString("LIMIT") out.insertParametrizedArgument(l.Count) } @@ -128,7 +128,7 @@ type ClauseOffset struct { func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) error { if o.Count >= 0 { - out.newLine() + out.NewLine() out.WriteString("OFFSET") out.insertParametrizedArgument(o.Count) } @@ -145,7 +145,7 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) erro return nil } - out.newLine() + out.NewLine() out.WriteString("FOR") return f.Lock.serialize(statementType, out) } @@ -179,20 +179,20 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB //} if wrap { - out.newLine() + out.NewLine() out.WriteString("(") out.increaseIdent() } for i, selectStmt := range s.Selects { - out.newLine() + out.NewLine() if i > 0 { out.WriteString(s.Operator) if s.All { out.WriteString("ALL") } - out.newLine() + out.NewLine() } if selectStmt == nil { @@ -208,7 +208,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB if wrap { out.decreaseIdent() - out.newLine() + out.NewLine() out.WriteString(")") } @@ -238,7 +238,7 @@ type ClauseUpdate struct { } func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString("UPDATE") if utils.IsNil(u.Table) { @@ -258,42 +258,33 @@ type ClauseSet struct { } func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString("SET") - if len(s.Columns) == 0 { - return errors.New("jet: no columns selected") + if len(s.Columns) != len(s.Values) { + return errors.New("jet: mismatch in numers of columns and values") } - if len(s.Columns) > 1 { - out.WriteString("(") - } - - err := SerializeColumnNames(s.Columns, out) - - if err != nil { - return err - } - - if len(s.Columns) > 1 { - out.WriteString(")") - } - - out.WriteString("=") - - if len(s.Values) > 1 { - out.WriteString("(") - } - - err = SerializeClauseList(statementType, s.Values, out) - - if err != nil { - return err - } - - if len(s.Values) > 1 { - out.WriteString(")") + out.increaseIdent(4) + for i, column := range s.Columns { + if i > 0 { + out.WriteString(", ") + out.NewLine() + } + + if column == nil { + return errors.New("jet: nil column in columns list") + } + + out.WriteString(column.Name()) + + out.WriteString(" = ") + + if err := Serialize(s.Values[i], UpdateStatementType, out); err != nil { + return err + } } + out.decreaseIdent(4) return nil } @@ -312,7 +303,7 @@ type ClauseInsert struct { } func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString("INSERT INTO") if utils.IsNil(i.Table) { @@ -357,7 +348,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e } out.increaseIdent() - out.newLine() + out.NewLine() out.WriteString("(") err := SerializeClauseList(statementType, row, out) @@ -389,7 +380,7 @@ type ClauseDelete struct { } func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString("DELETE FROM") if d.Table == nil { @@ -409,7 +400,7 @@ type ClauseStatementBegin struct { } func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString(d.Name) for i, table := range d.Tables { @@ -433,7 +424,7 @@ type ClauseString struct { } func (d *ClauseString) Serialize(statementType StatementType, out *SqlBuilder) error { - out.newLine() + out.NewLine() out.WriteString(d.Name) out.WriteString(d.Data) return nil @@ -573,7 +564,7 @@ func (t *JoinTableImpl) TableName() string { return "" } -func (t *JoinTableImpl) columns() []IColumn { +func (t *JoinTableImpl) Columns() []IColumn { //return append(t.lhs.columns(), t.rhs.columns()...) panic("Unimplemented") } @@ -601,7 +592,7 @@ func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, opti return } - out.newLine() + out.NewLine() switch t.joinType { case InnerJoin: @@ -667,7 +658,7 @@ func (s *SelectTableImpl2) Alias() string { return s.alias } -func (s *SelectTableImpl2) columns() []IColumn { +func (s *SelectTableImpl2) Columns() []IColumn { return nil } diff --git a/internal/jet/delete_statement.go b/internal/jet/delete_statement.go index c2108a9..fe6f448 100644 --- a/internal/jet/delete_statement.go +++ b/internal/jet/delete_statement.go @@ -48,7 +48,7 @@ func (d *deleteStatementImpl) serializeImpl(out *SqlBuilder) error { if d == nil { return errors.New("jet: delete statement is nil") } - out.newLine() + out.NewLine() out.WriteString("DELETE FROM") if d.table == nil { diff --git a/internal/jet/dialects.go b/internal/jet/dialects.go index 56a2c4b..912ced1 100644 --- a/internal/jet/dialects.go +++ b/internal/jet/dialects.go @@ -1,7 +1,6 @@ package jet import ( - "errors" "strconv" ) @@ -11,7 +10,6 @@ var ANSII = NewDialect(DialectParams{ // just for tests ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, - SupportsReturning: true, }) type Dialect interface { @@ -21,8 +19,6 @@ type Dialect interface { AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc - SetClause() func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) - SupportsReturning() bool } type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...SerializeOption) error @@ -38,9 +34,6 @@ type DialectParams struct { AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc - SetClause func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) - - SupportsReturning bool } func NewDialect(params DialectParams) Dialect { @@ -51,8 +44,6 @@ func NewDialect(params DialectParams) Dialect { aliasQuoteChar: params.AliasQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, - setClause: params.SetClause, - supportsReturning: params.SupportsReturning, } } @@ -91,41 +82,3 @@ func (d *dialectImpl) IdentifierQuoteChar() byte { func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { return d.argumentPlaceholder } - -func (d *dialectImpl) SetClause() func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) { - if d.setClause != nil { - return d.setClause - } - return setClause -} - -func (d *dialectImpl) SupportsReturning() bool { - return d.supportsReturning -} - -func setClause(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) { - - if len(columns) != len(values) { - return errors.New("jet: mismatch in numers of columns and values") - } - - for i, column := range columns { - if i > 0 { - out.WriteString(", ") - } - - if column == nil { - return errors.New("jet: nil column in columns list") - } - - out.WriteString(column.Name()) - - out.WriteString(" = ") - - if err = Serialize(values[i], UpdateStatementType, out); err != nil { - return err - } - } - - return nil -} diff --git a/internal/jet/insert_statement.go b/internal/jet/insert_statement.go index 0ff6db8..3c299be 100644 --- a/internal/jet/insert_statement.go +++ b/internal/jet/insert_statement.go @@ -88,7 +88,7 @@ func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []inte Dialect: detectDialect(i, dialect...), } - out.newLine() + out.NewLine() out.WriteString("INSERT INTO") if utils.IsNil(i.table) { @@ -132,7 +132,7 @@ func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []inte } out.increaseIdent() - out.newLine() + out.NewLine() out.WriteString("(") err = SerializeClauseList(InsertStatementType, row, out) diff --git a/internal/jet/lock_statement.go b/internal/jet/lock_statement.go index 397d5f5..f49f2bc 100644 --- a/internal/jet/lock_statement.go +++ b/internal/jet/lock_statement.go @@ -66,7 +66,7 @@ func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interf Dialect: detectDialect(l, dialect...), } - out.newLine() + out.NewLine() out.WriteString("LOCK TABLE") for i, table := range l.tables { diff --git a/internal/jet/select_statement.go b/internal/jet/select_statement.go index 6162ebc..77e2d9a 100644 --- a/internal/jet/select_statement.go +++ b/internal/jet/select_statement.go @@ -163,7 +163,7 @@ func (s *selectStatementImpl) serialize(statement StatementType, out *SqlBuilder return err } - out.newLine() + out.NewLine() out.WriteString(")") return nil @@ -174,7 +174,7 @@ func (s *selectStatementImpl) serializeImpl(out *SqlBuilder) error { return errors.New("jet: Select expression is nil. ") } - out.newLine() + out.NewLine() out.WriteString("SELECT") if s.distinct { @@ -230,19 +230,19 @@ func (s *selectStatementImpl) serializeImpl(out *SqlBuilder) error { } if s.limit >= 0 { - out.newLine() + out.NewLine() out.WriteString("LIMIT") out.insertParametrizedArgument(s.limit) } if s.offset >= 0 { - out.newLine() + out.NewLine() out.WriteString("OFFSET") out.insertParametrizedArgument(s.offset) } if s.lockFor != nil { - out.newLine() + out.NewLine() out.WriteString("FOR") err := s.lockFor.serialize(SelectStatementType, out) diff --git a/internal/jet/set_statement.go b/internal/jet/set_statement.go index 9d2b83a..77dd783 100644 --- a/internal/jet/set_statement.go +++ b/internal/jet/set_statement.go @@ -109,7 +109,7 @@ func (s *setStatementImpl) serialize(statement StatementType, out *SqlBuilder, o if wrap { out.decreaseIdent() - out.newLine() + out.NewLine() out.WriteString(")") } @@ -129,19 +129,19 @@ func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error { return setOverride()(SelectStatementType, out) } - out.newLine() + out.NewLine() out.WriteString("(") out.increaseIdent() for i, selectStmt := range s.selects { - out.newLine() + out.NewLine() if i > 0 { out.WriteString(s.operator) if s.all { out.WriteString("ALL") } - out.newLine() + out.NewLine() } if selectStmt == nil { @@ -156,7 +156,7 @@ func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error { } out.decreaseIdent() - out.newLine() + out.NewLine() out.WriteString(")") if s.orderBy != nil { @@ -167,13 +167,13 @@ func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error { } if s.limit >= 0 { - out.newLine() + out.NewLine() out.WriteString("LIMIT") out.insertParametrizedArgument(s.limit) } if s.offset >= 0 { - out.newLine() + out.NewLine() out.WriteString("OFFSET") out.insertParametrizedArgument(s.offset) } diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 4bb4b5c..dff7cea 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -24,16 +24,26 @@ func (s *SqlBuilder) DebugSQL() string { const defaultIdent = 5 -func (q *SqlBuilder) increaseIdent() { - q.ident += defaultIdent +func (q *SqlBuilder) increaseIdent(ident ...int) { + if len(ident) > 0 { + q.ident += ident[0] + } else { + q.ident += defaultIdent + } } -func (q *SqlBuilder) decreaseIdent() { - if q.ident < defaultIdent { +func (q *SqlBuilder) decreaseIdent(ident ...int) { + toDecrease := defaultIdent + + if len(ident) > 0 { + toDecrease = ident[0] + } + + if q.ident < toDecrease { q.ident = 0 } - q.ident -= defaultIdent + q.ident -= toDecrease } func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error { @@ -44,7 +54,7 @@ func (q *SqlBuilder) writeProjections(statement StatementType, projections []Pro } func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error { - q.newLine() + q.NewLine() q.WriteString("FROM") q.increaseIdent() @@ -55,7 +65,7 @@ func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error } func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error { - q.newLine() + q.NewLine() q.WriteString("WHERE") q.increaseIdent() @@ -66,7 +76,7 @@ func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error } func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClause) error { - q.newLine() + q.NewLine() q.WriteString("GROUP BY") q.increaseIdent() @@ -77,7 +87,7 @@ func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClau } func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClause) error { - q.newLine() + q.NewLine() q.WriteString("ORDER BY") q.increaseIdent() @@ -88,7 +98,7 @@ func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClau } func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) error { - q.newLine() + q.NewLine() q.WriteString("HAVING") q.increaseIdent() @@ -103,18 +113,14 @@ func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Project return nil } - if !q.Dialect.SupportsReturning() { - panic("jet: " + q.Dialect.Name() + " dialect does not support RETURNING.") - } - - q.newLine() + q.NewLine() q.WriteString("RETURNING") q.increaseIdent() return q.writeProjections(statement, returning) } -func (q *SqlBuilder) newLine() { +func (q *SqlBuilder) NewLine() { q.write([]byte{'\n'}) q.write(bytes.Repeat([]byte{' '}, q.ident)) } diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 3e77932..4057a7d 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -217,7 +217,7 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti if !contains(options, noWrap) { out.decreaseIdent() - out.newLine() + out.NewLine() out.WriteString(")") } diff --git a/internal/jet/table.go b/internal/jet/table.go index 923a3ad..c0b69da 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -7,10 +7,7 @@ import ( type SerializerTable interface { Serializer - Columns() []IColumn - //SchemaName() string - //TableName() string - //AS(alias string) + TableInterface } type TableInterface interface { @@ -284,7 +281,7 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options return } - out.newLine() + out.NewLine() switch t.joinType { case InnerJoin: diff --git a/internal/jet/update_statement.go b/internal/jet/update_statement.go index 5a7b235..59dc9ed 100644 --- a/internal/jet/update_statement.go +++ b/internal/jet/update_statement.go @@ -67,7 +67,7 @@ func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []inte Dialect: detectDialect(u, dialect...), } - out.newLine() + out.NewLine() out.WriteString("UPDATE") if utils.IsNil(u.table) { @@ -86,13 +86,9 @@ func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []inte return "", nil, errors.New("jet: no values to updated") } - out.newLine() + out.NewLine() out.WriteString("SET") - if err = out.Dialect.SetClause()(u.columns, u.values, out); err != nil { - return - } - if u.where == nil { return "", nil, errors.New("jet: WHERE clause not set") } diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 2945377..e1a7924 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -83,7 +83,7 @@ func SerializeProjectionList(statement StatementType, projections []Projection, for i, col := range projections { if i > 0 { out.WriteString(",") - out.newLine() + out.NewLine() } if col == nil { diff --git a/mysql/delete_statement.go b/mysql/delete_statement.go new file mode 100644 index 0000000..8aa3977 --- /dev/null +++ b/mysql/delete_statement.go @@ -0,0 +1,33 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +type DeleteStatement interface { + jet.Statement + + WHERE(expression BoolExpression) Statement +} + +type deleteStatementImpl struct { + jet.StatementImpl + + Delete jet.ClauseStatementBegin + Where jet.ClauseWhere +} + +func newDeleteStatement(table Table) DeleteStatement { + newDelete := &deleteStatementImpl{} + newDelete.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, &newDelete.Delete, + &newDelete.Where) + + newDelete.Delete.Name = "DELETE FROM" + newDelete.Delete.Tables = append(newDelete.Delete.Tables, table) + newDelete.Where.Mandatory = true + + return newDelete +} + +func (d *deleteStatementImpl) WHERE(expression BoolExpression) Statement { + d.Where.Condition = expression + return d +} diff --git a/mysql/delete_statement_test.go b/mysql/delete_statement_test.go new file mode 100644 index 0000000..5d84802 --- /dev/null +++ b/mysql/delete_statement_test.go @@ -0,0 +1,17 @@ +package mysql + +import ( + "testing" +) + +func TestDeleteUnconditionally(t *testing.T) { + assertStatementErr(t, table1.DELETE(), `jet: WHERE clause not set`) + assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) +} + +func TestDeleteWithWhere(t *testing.T) { + assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` +DELETE FROM db.table1 +WHERE table1.col1 = ?; +`, int64(1)) +} diff --git a/mysql/dialect.go b/mysql/dialect.go index 6b05ead..bb61adf 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -15,8 +15,6 @@ func NewDialect() jet.Dialect { serializeOverrides["/"] = mysql_DIVISION serializeOverrides["#"] = mysql_BIT_XOR serializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator - serializeOverrides[jet.Except] = mysql_EXCEPT - serializeOverrides[jet.Intersect] = mysql_INTERSECT mySQLDialectParams := jet.DialectParams{ Name: "MySQL", @@ -27,24 +25,11 @@ func NewDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, - SupportsReturning: false, } return jet.NewDialect(mySQLDialectParams) } -func mysql_EXCEPT(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { - panic("jet: MySQL does not support EXCEPT operator.") - } -} - -func mysql_INTERSECT(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { - panic("jet: MySQL does not support INTERSECT operator.") - } -} - func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { if len(expressions) != 2 { diff --git a/mysql/expressions.go b/mysql/expressions.go index 2ff56b8..297507d 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -32,3 +32,5 @@ var TimestampExp = jet.TimestampExp var Raw = jet.Raw var NewEnumValue = jet.NewEnumValue + +type Statement jet.Statement diff --git a/mysql/functions.go b/mysql/functions.go index 981580a..073d615 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -27,8 +27,6 @@ var TRUNCATE = func(floatExpression jet.FloatExpression, precision jet.IntegerEx return jet.NewFloatFunc("TRUNCATE", floatExpression, precision) } -//var MINUSi = jet.MINUSi -//var MINUSf = jet.MINUSf var BIT_NOT = jet.BIT_NOT // ----------------- Aggregate functions -------------------// diff --git a/mysql/insert_statement.go b/mysql/insert_statement.go new file mode 100644 index 0000000..21c239f --- /dev/null +++ b/mysql/insert_statement.go @@ -0,0 +1,64 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +// InsertStatement is interface for SQL INSERT statements +type InsertStatement interface { + jet.Statement + + // Insert row of values + VALUES(value interface{}, values ...interface{}) InsertStatement + // Insert row of values, where value for each column is extracted from filed of structure data. + // If data is not struct or there is no field for every column selected, this method will panic. + MODEL(data interface{}) InsertStatement + MODELS(data interface{}) InsertStatement + + QUERY(selectStatement SelectStatement) InsertStatement +} + +func newInsertStatement(table Table, columns []jet.IColumn) InsertStatement { + newInsert := &insertStatementImpl{} + newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newInsert, + &newInsert.Insert, &newInsert.Values, &newInsert.Select) + + newInsert.Insert.Table = table + newInsert.Insert.Columns = columns + + return newInsert +} + +type insertStatementImpl struct { + jet.StatementImpl + + Insert jet.ClauseInsert + Values jet.ClauseValues + Select jet.ClauseQuery +} + +func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromValues(value, values)) + return i +} + +func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.getColumns(), data)) + return i +} + +func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.getColumns(), data)...) + return i +} + +func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + i.Select.Query = selectStatement + return i +} + +func (i *insertStatementImpl) getColumns() []jet.IColumn { + if len(i.Insert.Columns) > 0 { + return i.Insert.Columns + } + + return i.Insert.Table.Columns() +} diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go new file mode 100644 index 0000000..cb1513b --- /dev/null +++ b/mysql/insert_statement_test.go @@ -0,0 +1,134 @@ +package mysql + +import ( + "gotest.tools/assert" + "testing" + "time" +) + +//TODO: +//func TestInvalidInsert(t *testing.T) { +// assertStatementErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") +// assertStatementErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") +//} + +func TestInsertNilValue(t *testing.T) { + assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), ` +INSERT INTO db.table1 (col1) VALUES + (?); +`, nil) +} + +func TestInsertSingleValue(t *testing.T) { + assertStatement(t, table1.INSERT(table1Col1).VALUES(1), ` +INSERT INTO db.table1 (col1) VALUES + (?); +`, int(1)) +} + +func TestInsertWithColumnList(t *testing.T) { + columnList := ColumnList(table3ColInt, table3StrCol) + + assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), ` +INSERT INTO db.table3 (col_int, col2) VALUES + (?, ?); +`, 1, 3) +} + +func TestInsertDate(t *testing.T) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + assertStatement(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` +INSERT INTO db.table1 (col_timestamp) VALUES + (?); +`, date) +} + +func TestInsertMultipleValues(t *testing.T) { + assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` +INSERT INTO db.table1 (col1, col_float, col3) VALUES + (?, ?, ?); +`, 1, 2, 3) +} + +func TestInsertMultipleRows(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColFloat). + VALUES(1, 2). + VALUES(11, 22). + VALUES(111, 222) + + assertStatement(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) VALUES + (?, ?), + (?, ?), + (?, ?); +`, 1, 2, 11, 22, 111, 222) +} + +func TestInsertValuesFromModel(t *testing.T) { + type Table1Model struct { + Col1 *int + ColFloat float64 + } + + one := 1 + + toInsert := Table1Model{ + Col1: &one, + ColFloat: 1.11, + } + + stmt := table1.INSERT(table1Col1, table1ColFloat). + MODEL(toInsert). + MODEL(&toInsert) + + expectedSQL := ` +INSERT INTO db.table1 (col1, col_float) VALUES + (?, ?), + (?, ?); +` + + assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) +} + +func TestInsertValuesFromModelColumnMismatch(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, r, "missing struct field for column : col1") + }() + type Table1Model struct { + Col1Prim int + Col2 string + } + + newData := Table1Model{ + Col1Prim: 1, + Col2: "one", + } + + table1. + INSERT(table1Col1, table1ColFloat). + MODEL(newData) +} + +func TestInsertFromNonStructModel(t *testing.T) { + + defer func() { + r := recover() + assert.Equal(t, r, "argument mismatch: expected struct, got []int") + }() + + table2.INSERT(table2ColInt).MODEL([]int{}) +} + +func TestInsertDefaultValue(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColFloat). + VALUES(DEFAULT, "two") + + var expectedSQL = ` +INSERT INTO db.table1 (col1, col_float) VALUES + (DEFAULT, ?); +` + + assertStatement(t, stmt, expectedSQL, "two") +} diff --git a/mysql/literal.go b/mysql/literal.go index fd18c4f..a2b98db 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -5,6 +5,10 @@ import ( "time" ) +var STAR = jet.STAR +var NULL = jet.NULL +var DEFAULT = jet.DEFAULT + var Bool = jet.Bool var Int = jet.Int var Float = jet.Float diff --git a/mysql/select_statement.go b/mysql/select_statement.go new file mode 100644 index 0000000..7abb9f3 --- /dev/null +++ b/mysql/select_statement.go @@ -0,0 +1,118 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +type SelectLock = jet.SelectLock + +var ( + UPDATE = jet.NewSelectLock("UPDATE") + SHARE = jet.NewSelectLock("SHARE") +) + +type SelectStatement interface { + jet.Statement + jet.HasProjections + jet.IExpression + + DISTINCT() SelectStatement + FROM(table ReadableTable) SelectStatement + WHERE(expression BoolExpression) SelectStatement + GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement + HAVING(boolExpression BoolExpression) SelectStatement + ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement + LIMIT(limit int64) SelectStatement + OFFSET(offset int64) SelectStatement + FOR(lock SelectLock) SelectStatement + + UNION(rhs SelectStatement) SetStatement + UNION_ALL(rhs SelectStatement) SetStatement + + AsTable(alias string) SelectTable +} + +//SELECT creates new SelectStatement with list of projections +func SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement { + return newSelectStatement(nil, append([]jet.Projection{projection}, projections...)) +} + +func newSelectStatement(table ReadableTable, projections []jet.Projection) SelectStatement { + newSelect := &selectStatementImpl{} + newSelect.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, + &newSelect.Limit, &newSelect.Offset, &newSelect.For) + + newSelect.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSelect + + newSelect.Select.Projections = projections + newSelect.From.Table = table + newSelect.Limit.Count = -1 + newSelect.Offset.Count = -1 + + newSelect.setOperatorsImpl.parent = newSelect + + return newSelect +} + +type selectStatementImpl struct { + jet.ExpressionStatementImpl + setOperatorsImpl + + Select jet.ClauseSelect + From jet.ClauseFrom + Where jet.ClauseWhere + GroupBy jet.ClauseGroupBy + Having jet.ClauseHaving + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Offset jet.ClauseOffset + For jet.ClauseFor +} + +func (s *selectStatementImpl) DISTINCT() SelectStatement { + s.Select.Distinct = true + return s +} + +func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { + s.From.Table = table + return s +} + +func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement { + s.Where.Condition = condition + return s +} + +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement { + s.GroupBy.List = groupByClauses + return s +} + +func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatement { + s.Having.Condition = boolExpression + return s +} + +func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { + s.OrderBy.List = orderByClauses + return s +} + +func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { + s.Limit.Count = limit + return s +} + +func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.Offset.Count = offset + return s +} + +func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { + s.For.Lock = lock + return s +} + +func (s *selectStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go new file mode 100644 index 0000000..96f0c54 --- /dev/null +++ b/mysql/select_statement_test.go @@ -0,0 +1,126 @@ +package mysql + +import ( + "github.com/go-jet/jet/internal/testutils" + "testing" +) + +func TestInvalidSelect(t *testing.T) { + assertStatementErr(t, SELECT(nil), "jet: Projection is nil") +} + +func TestSelectColumnList(t *testing.T) { + columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt) + + assertStatement(t, SELECT(columnList).FROM(table2), ` +SELECT table2.col_int AS "table2.col_int", + table2.col_float AS "table2.col_float", + table3.col_int AS "table3.col_int" +FROM db.table2; +`) +} + +func TestSelectLiterals(t *testing.T) { + assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` +SELECT ?, + ?, + ? +FROM db.table1; +`, int64(1), 2.2, false) +} + +func TestSelectDistinct(t *testing.T) { + assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` +SELECT DISTINCT table1.col_bool AS "table1.col_bool" +FROM db.table1; +`) +} + +func TestSelectFrom(t *testing.T) { + assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1; +`) + assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) + assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) +} + +func TestSelectWhere(t *testing.T) { + assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE ?; +`, true) + assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE table1.col_int >= ?; +`, int64(10)) +} + +func TestSelectGroupBy(t *testing.T) { + assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +GROUP BY table2.col_float; +`) +} + +func TestSelectHaving(t *testing.T) { + assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` +SELECT table3.col_int AS "table3.col_int" +FROM db.table3 +HAVING table1.col_bool = ?; +`, true) +} + +func TestSelectOrderBy(t *testing.T) { + assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC; +`) + assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC, table2.col_int ASC; +`) +} + +func TestSelectLimitOffset(t *testing.T) { + assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT ?; +`, int64(10)) + assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT ? +OFFSET ?; +`, int64(10), int64(2)) +} + +func TestSelectLock(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR SHARE NOWAIT; +`) +} diff --git a/mysql/select_table.go b/mysql/select_table.go new file mode 100644 index 0000000..7a314d8 --- /dev/null +++ b/mysql/select_table.go @@ -0,0 +1,23 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +type SelectTable interface { + ReadableTable + jet.SelectTable +} + +type selectTableImpl struct { + jet.SelectTableImpl2 + readableTableInterfaceImpl +} + +func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTableImpl2: jet.NewSelectTable(selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/mysql/set_statement.go b/mysql/set_statement.go new file mode 100644 index 0000000..86ce186 --- /dev/null +++ b/mysql/set_statement.go @@ -0,0 +1,104 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +// UNION effectively appends the result of sub-queries(select statements) into single query. +// It eliminates duplicate rows from its result. +func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Union, false, toSelectList(lhs, rhs, selects...)) +} + +// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. +// It does not eliminates duplicate rows from its result. +func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Union, true, toSelectList(lhs, rhs, selects...)) +} + +type SetStatement interface { + SetOperators + + ORDER_BY(orderByClauses ...jet.OrderByClause) SetStatement + + LIMIT(limit int64) SetStatement + OFFSET(offset int64) SetStatement + + AsTable(alias string) SelectTable +} + +type SetStatementFinal interface { +} + +type SetOperators interface { + jet.Statement + jet.HasProjections + jet.IExpression + + UNION(rhs SelectStatement) SetStatement + UNION_ALL(rhs SelectStatement) SetStatement +} + +type setOperatorsImpl struct { + parent SetOperators +} + +func (s *setOperatorsImpl) UNION(rhs SelectStatement) SetStatement { + return UNION(s.parent, rhs) +} + +func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) SetStatement { + return UNION_ALL(s.parent, rhs) +} + +type setStatementImpl struct { + jet.ExpressionStatementImpl + + setOperatorsImpl + + setOperator jet.ClauseSetStmtOperator +} + +func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) SetStatement { + newSetStatement := &setStatementImpl{} + newSetStatement.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SetStatementType, newSetStatement, + &newSetStatement.setOperator) + newSetStatement.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSetStatement + + newSetStatement.setOperator.Operator = operator + newSetStatement.setOperator.All = all + newSetStatement.setOperator.Selects = selects + newSetStatement.setOperator.Limit.Count = -1 + newSetStatement.setOperator.Offset.Count = -1 + + newSetStatement.setOperatorsImpl.parent = newSetStatement + + newSetStatement.Clauses = []jet.Clause{&newSetStatement.setOperator} + + return newSetStatement +} + +func (s *setStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SetStatement { + s.setOperator.OrderBy.List = orderByClauses + return s +} + +func (s *setStatementImpl) LIMIT(limit int64) SetStatement { + s.setOperator.Limit.Count = limit + return s +} + +func (s *setStatementImpl) OFFSET(offset int64) SetStatement { + s.setOperator.Offset.Count = offset + return s +} + +func (s *setStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +const ( + Union = "UNION" +) + +func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections { + return append([]jet.StatementWithProjections{lhs, rhs}, selects...) +} diff --git a/mysql/set_statement_test.go b/mysql/set_statement_test.go new file mode 100644 index 0000000..950b511 --- /dev/null +++ b/mysql/set_statement_test.go @@ -0,0 +1,33 @@ +package mysql + +import ( + "testing" +) + +func TestSelectSets(t *testing.T) { + select1 := SELECT(table1ColBool).FROM(table1) + select2 := SELECT(table2ColBool).FROM(table2) + + assertStatement(t, select1.UNION(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +UNION +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + assertStatement(t, select1.UNION_ALL(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +UNION ALL +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) +} diff --git a/mysql/statements.go b/mysql/statements.go deleted file mode 100644 index 30cd202..0000000 --- a/mysql/statements.go +++ /dev/null @@ -1,23 +0,0 @@ -package mysql - -import "github.com/go-jet/jet/internal/jet" - -// ----------------- FUNCTIONS ----------------------// - -var SELECT = jet.SELECT - -type SelectLock jet.SelectLock - -var ( - UPDATE = jet.NewSelectLock("UPDATE") - SHARE = jet.NewSelectLock("SHARE") -) - -var UNION = jet.UNION -var UNION_ALL = jet.UNION_ALL - -//-----------------literals----------------------// - -var STAR = jet.STAR -var NULL = jet.NULL -var DEFAULT = jet.DEFAULT diff --git a/mysql/table.go b/mysql/table.go index 8eee2b2..c463882 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -2,8 +2,126 @@ package mysql import "github.com/go-jet/jet/internal/jet" -type Table jet.Table +//type Table jet.Table +// +//func NewTable(schemaName, name string, columns ...jet.Column) Table { +// return jet.NewTable(Dialect, schemaName, name, columns...) +//} + +type Table interface { + jet.SerializerTable + readableTable + + INSERT(columns ...jet.IColumn) InsertStatement + UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement + DELETE() DeleteStatement + //LOCK() LockStatement + + AS(alias string) +} + +type readableTable interface { + // Generates a select query on the current tableName. + SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement + + // Creates a inner join tableName Expression using onCondition. + INNER_JOIN(table ReadableTable, onCondition BoolExpression) Table + + // Creates a left join tableName Expression using onCondition. + LEFT_JOIN(table ReadableTable, onCondition BoolExpression) Table + + // Creates a right join tableName Expression using onCondition. + RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) Table + + // Creates a full join tableName Expression using onCondition. + FULL_JOIN(table ReadableTable, onCondition BoolExpression) Table + + // Creates a cross join tableName Expression using onCondition. + CROSS_JOIN(table ReadableTable) Table +} + +type ReadableTable interface { + jet.SerializerTable + readableTable +} + +type readableTableInterfaceImpl struct { + parent ReadableTable +} + +// Generates a select query on the current tableName. +func (r *readableTableInterfaceImpl) SELECT(projection1 jet.Projection, projections ...jet.Projection) SelectStatement { + return newSelectStatement(r.parent, append([]jet.Projection{projection1}, projections...)) +} + +// Creates a inner join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) Table { + return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) +} + +// Creates a left join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) Table { + return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) +} + +// Creates a right join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) Table { + return newJoinTable(r.parent, table, jet.RightJoin, onCondition) +} + +func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) Table { + return newJoinTable(r.parent, table, jet.FullJoin, onCondition) +} + +func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) Table { + return newJoinTable(r.parent, table, jet.CrossJoin, nil) +} func NewTable(schemaName, name string, columns ...jet.Column) Table { - return jet.NewTable(Dialect, schemaName, name, columns...) + t := &tableImpl{ + TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), + } + + t.readableTableInterfaceImpl.parent = t + t.parent = t + + return t +} + +type tableImpl struct { + jet.TableImpl2 + readableTableInterfaceImpl + parent Table +} + +func (w *tableImpl) INSERT(columns ...jet.IColumn) InsertStatement { + return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) +} + +func (w *tableImpl) UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement { + return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) +} + +func (w *tableImpl) DELETE() DeleteStatement { + return newDeleteStatement(w.parent) +} + +//func (w *tableInterfaceImpl) LOCK() LockStatement { +// return LOCK(w.parent) +//} + +type joinTable2 struct { + tableImpl + jet.JoinTableImpl +} + +func newJoinTable(lhs jet.Serializer, rhs jet.Serializer, joinType jet.JoinType, onCondition BoolExpression) Table { + newJoinTable := &joinTable2{ + JoinTableImpl: jet.NewJoinTableImpl(lhs, rhs, joinType, onCondition), + } + + newJoinTable.readableTableInterfaceImpl.parent = newJoinTable + newJoinTable.parent = newJoinTable + + return newJoinTable } diff --git a/mysql/update_statement.go b/mysql/update_statement.go new file mode 100644 index 0000000..cf82972 --- /dev/null +++ b/mysql/update_statement.go @@ -0,0 +1,48 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +// UpdateStatement is interface of SQL UPDATE statement +type UpdateStatement interface { + jet.Statement + + SET(value interface{}, values ...interface{}) UpdateStatement + MODEL(data interface{}) UpdateStatement + + WHERE(expression BoolExpression) UpdateStatement +} + +type updateStatementImpl struct { + jet.StatementImpl + + Update jet.ClauseUpdate + Set jet.ClauseSet + Where jet.ClauseWhere +} + +func newUpdateStatement(table Table, columns []jet.IColumn) UpdateStatement { + update := &updateStatementImpl{} + update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, + &update.Set, &update.Where) + + update.Update.Table = table + update.Set.Columns = columns + update.Where.Mandatory = true + + return update +} + +func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { + u.Set.Values = jet.UnwindRowFromValues(value, values) + return u +} + +func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { + u.Set.Values = jet.UnwindRowFromModel(u.Set.Columns, data) + return u +} + +func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { + u.Where.Condition = expression + return u +} diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go new file mode 100644 index 0000000..2980ef6 --- /dev/null +++ b/mysql/update_statement_test.go @@ -0,0 +1,63 @@ +package mysql + +import ( + "fmt" + "testing" +) + +func TestUpdateWithOneValue(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ? +WHERE table1.col_int >= ?; +` + stmt := table1.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatement(t, stmt, expectedSQL, 1, int64(33)) +} + +func TestUpdateWithValues(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ?, + col_float = ? +WHERE table1.col_int >= ?; +` + stmt := table1.UPDATE(table1ColInt, table1ColFloat). + SET(1, 22.2). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33)) +} + +func TestUpdateOneColumnWithSelect(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_float = ( + SELECT table1.col_float AS "table1.col_float" + FROM db.table1 + ) +WHERE table1.col1 = ?; +` + stmt := table1. + UPDATE(table1ColFloat). + SET( + table1.SELECT(table1ColFloat), + ). + WHERE(table1Col1.EQ(Int(2))) + + //fmt.Println(stmt.Sql()) + + assertStatement(t, stmt, expectedSQL, int64(2)) +} + +func TestInvalidInputs(t *testing.T) { + assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") + assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") +} diff --git a/postgres/dialect.go b/postgres/dialect.go index 81ddc76..d392542 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -24,8 +24,6 @@ func NewDialect() jet.Dialect { ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, - //SetClause: postgresSetClause, - SupportsReturning: true, } return jet.NewDialect(dialectParams) diff --git a/postgres/select_statement.go b/postgres/select_statement.go index ee3ed77..61b7a58 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -122,23 +122,3 @@ func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { func (s *selectStatementImpl) AsTable(alias string) SelectTable { return newSelectTable(s, alias) } - -type SelectTable interface { - ReadableTable - jet.SelectTable -} - -type selectTableImpl struct { - jet.SelectTableImpl2 - readableTableInterfaceImpl -} - -func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { - subQuery := &selectTableImpl{ - SelectTableImpl2: jet.NewSelectTable(selectStmt, alias), - } - - subQuery.readableTableInterfaceImpl.parent = subQuery - - return subQuery -} diff --git a/postgres/select_table.go b/postgres/select_table.go new file mode 100644 index 0000000..7b94607 --- /dev/null +++ b/postgres/select_table.go @@ -0,0 +1,23 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +type SelectTable interface { + ReadableTable + jet.SelectTable +} + +type selectTableImpl struct { + jet.SelectTableImpl2 + readableTableInterfaceImpl +} + +func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTableImpl2: jet.NewSelectTable(selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/postgres/table.go b/postgres/table.go index f7423dd..bce6fe3 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -120,10 +120,6 @@ func NewTable(schemaName, name string, columns ...jet.Column) Table { TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), } - for _, c := range columns { - c.SetTableName(name) - } - t.readableTableInterfaceImpl.parent = t t.writableTableInterfaceImpl.parent = t diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 2ec984c..b06fc98 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -1,6 +1,9 @@ package postgres -import "github.com/go-jet/jet/internal/jet" +import ( + "errors" + "github.com/go-jet/jet/internal/jet" +) // UpdateStatement is interface of SQL UPDATE statement type UpdateStatement interface { @@ -17,7 +20,7 @@ type updateStatementImpl struct { jet.StatementImpl Update jet.ClauseUpdate - Set jet.ClauseSet + Set ClauseSet Where jet.ClauseWhere Returning jet.ClauseReturning } @@ -53,3 +56,49 @@ func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateSta u.Returning.Projections = projections return u } + +type ClauseSet struct { + Columns []jet.IColumn + Values []jet.Serializer +} + +func (s *ClauseSet) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) error { + out.NewLine() + out.WriteString("SET") + + if len(s.Columns) == 0 { + return errors.New("jet: no columns selected") + } + + if len(s.Columns) > 1 { + out.WriteString("(") + } + + err := jet.SerializeColumnNames(s.Columns, out) + + if err != nil { + return err + } + + if len(s.Columns) > 1 { + out.WriteString(")") + } + + out.WriteString("=") + + if len(s.Values) > 1 { + out.WriteString("(") + } + + err = jet.SerializeClauseList(statementType, s.Values, out) + + if err != nil { + return err + } + + if len(s.Values) > 1 { + out.WriteString(")") + } + + return nil +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index d9b7ca2..2360a16 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -257,36 +257,6 @@ LIMIT ?; assert.NilError(t, err) } -func TestSelectINTERSECT(t *testing.T) { - defer func() { - r := recover() - assert.Equal(t, r, "jet: MySQL does not support INTERSECT operator.") - }() - - query := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). - INTERSECT(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1) - - //fmt.Println(query.DebugSql()) - - err := query.Query(db, &struct{}{}) - assert.NilError(t, err) -} - -func TestSelectEXCEPT(t *testing.T) { - defer func() { - r := recover() - assert.Equal(t, r, "jet: MySQL does not support EXCEPT operator.") - }() - - query := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). - EXCEPT(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1) - - //fmt.Println(query.DebugSql()) - - err := query.Query(db, &struct{}{}) - assert.NilError(t, err) -} - func TestSelectUNION_ALL(t *testing.T) { expectedSQL := ` ( @@ -303,21 +273,29 @@ func TestSelectUNION_ALL(t *testing.T) { LIMIT ? OFFSET ? ) -); +) +ORDER BY "payment.payment_id" +LIMIT ? +OFFSET ?; ` query := UNION_ALL( Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2), - ) + ).ORDER_BY(Payment.PaymentID). + LIMIT(4). + OFFSET(3) //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(10), int64(1), int64(2)) + testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(4), int64(3)) query2 := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). - UNION_ALL(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)) + UNION_ALL(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)). + ORDER_BY(Payment.PaymentID). + LIMIT(4). + OFFSET(3) - testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2)) + testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(4), int64(3)) err := query.Query(db, &struct{}{}) assert.NilError(t, err) diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index b65165a..8c2ce49 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -2,6 +2,7 @@ package mysql import ( "context" + "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" @@ -21,10 +22,12 @@ func TestUpdateValues(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link -SET name = 'Bong', url = 'http://bong.com' +SET name = 'Bong', + url = 'http://bong.com' WHERE link.name = 'Bing'; ` + fmt.Println(query.DebugSql()) testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, query, db) @@ -61,34 +64,21 @@ func TestUpdateWithSubQueries(t *testing.T) { expectedSQL := ` UPDATE test_sample.link SET name = ( - SELECT ? -), url = ( - SELECT link2.url AS "link2.url" - FROM test_sample.link2 - WHERE link2.name = ? -) + SELECT ? + ), + url = ( + SELECT link2.url AS "link2.url" + FROM test_sample.link2 + WHERE link2.name = ? + ) WHERE link.name = ?; ` + fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertExec(t, query, db) } -func TestUpdateAndReturning(t *testing.T) { - defer func() { - r := recover() - assert.Equal(t, r, "jet: MySQL dialect does not support RETURNING.") - }() - - stmt := Link. - UPDATE(Link.Name, Link.URL). - SET("DuckDuckGo", "http://www.duckduckgo.com"). - WHERE(Link.Name.EQ(String("Ask"))). - RETURNING(Link.AllColumns) - - stmt.Query(db, &struct{}{}) -} - func TestUpdateWithModelData(t *testing.T) { setupLinkTableForUpdateTest(t) @@ -105,9 +95,13 @@ func TestUpdateWithModelData(t *testing.T) { expectedSQL := ` UPDATE test_sample.link -SET id = ?, url = ?, name = ?, description = ? +SET id = ?, + url = ?, + name = ?, + description = ? WHERE link.id = ?; ` + fmt.Println(stmt.Sql()) testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) testutils.AssertExec(t, stmt, db) @@ -132,7 +126,9 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link -SET description = NULL, name = 'DuckDuckGo', url = 'http://www.duckduckgo.com' +SET description = NULL, + name = 'DuckDuckGo', + url = 'http://www.duckduckgo.com' WHERE link.id = 201; ` //fmt.Println(stmt.DebugSql())