From b2f84d048ceb1ddbeb5a7409e934011e8054bcdc Mon Sep 17 00:00:00 2001 From: zer0sub Date: Wed, 3 Apr 2019 11:03:07 +0200 Subject: [PATCH] Add StringColumn type and expression Add Projection type Alias refactoring More numeric operations --- generator/metadata/column_info.go | 25 +++++ generator/templates.go | 4 +- sqlbuilder/alias.go | 8 +- sqlbuilder/bool_expresion.go | 18 ++-- sqlbuilder/column.go | 17 --- sqlbuilder/column_types.go | 28 ++++- sqlbuilder/expression.go | 77 ++++++++------ sqlbuilder/func.go | 45 ++++---- sqlbuilder/numeric_expression.go | 54 +++++++++- sqlbuilder/projection.go | 26 +++++ sqlbuilder/select_statement.go | 10 +- sqlbuilder/select_statement_table.go | 34 +++--- sqlbuilder/statement.go | 2 +- sqlbuilder/string_expression.go | 25 +++++ sqlbuilder/table.go | 22 ++-- tests/generator_test.go | 154 ++++++++++++++------------- 16 files changed, 350 insertions(+), 199 deletions(-) create mode 100644 sqlbuilder/projection.go create mode 100644 sqlbuilder/string_expression.go diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index 80ebed1..ba83d9f 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -26,6 +26,31 @@ func (c ColumnInfo) ToGoVarName() string { return snaker.SnakeToCamel(c.Name) + "Column" } +func (c ColumnInfo) ToSqlBuilderColumnType() string { + switch c.DataType { + case "boolean": + return "BoolColumn" + case "smallint": + return "IntegerColumn" + case "integer": + return "IntegerColumn" + case "bigint": + return "IntegerColumn" + case "date", "timestamp without time zone", "timestamp with time zone": + return "StringColumn" + case "bytea": + return "StringColumn" + case "text": + return "StringColumn" + case "real": + return "NumericColumn" + case "numeric", "double precision": + return "NumericColumn" + default: + return "StringColumn" + } +} + func (c ColumnInfo) ToGoType() string { typeStr := c.GoBaseType() if c.IsNullable || c.TableInfo.IsForeignKey(c.Name) { diff --git a/generator/templates.go b/generator/templates.go index 8dc9766..d7ec706 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -11,7 +11,7 @@ type {{.ToGoStructName}} struct { //Columns {{- range .Columns}} - {{.ToGoFieldName}} sqlbuilder.NonAliasColumn + {{.ToGoFieldName}} *sqlbuilder.{{.ToSqlBuilderColumnType}} {{- end}} AllColumns sqlbuilder.ColumnList @@ -22,7 +22,7 @@ var {{.ToGoVarName}} = new{{.ToGoStructName}}() func new{{.ToGoStructName}}() *{{.ToGoStructName}} { var ( {{- range .Columns}} - {{.ToGoVarName}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}}) + {{.ToGoVarName}} = sqlbuilder.New{{.ToSqlBuilderColumnType}}("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}}) {{- end}} ) diff --git a/sqlbuilder/alias.go b/sqlbuilder/alias.go index aef842b..e852e29 100644 --- a/sqlbuilder/alias.go +++ b/sqlbuilder/alias.go @@ -3,24 +3,18 @@ package sqlbuilder import "bytes" type Alias struct { - Clause - expression Expression alias string } func NewAlias(expression Expression, alias string) *Alias { - if !validIdentifierName(alias) { - panic("Invalid alias") - } - return &Alias{ expression: expression, alias: alias, } } -func (a *Alias) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (a *Alias) SerializeForProjection(out *bytes.Buffer) error { err := a.expression.SerializeSql(out, ALIASED) diff --git a/sqlbuilder/bool_expresion.go b/sqlbuilder/bool_expresion.go index f5d3061..a7458fc 100644 --- a/sqlbuilder/bool_expresion.go +++ b/sqlbuilder/bool_expresion.go @@ -31,15 +31,15 @@ func (b *boolInterfaceImpl) Eq(expression BoolExpression) BoolExpression { } func (b *boolInterfaceImpl) NotEq(expression BoolExpression) BoolExpression { - return Neq(b.parent, expression) + return NotEq(b.parent, expression) } func (b *boolInterfaceImpl) GtEq(rhs Expression) BoolExpression { - return Gte(b.parent, rhs) + return GtEq(b.parent, rhs) } func (b *boolInterfaceImpl) LtEq(rhs Expression) BoolExpression { - return Lte(b.parent, rhs) + return LtEq(b.parent, rhs) } func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression { @@ -196,7 +196,7 @@ func EqL(lhs Expression, val interface{}) BoolExpression { } // Returns a representation of "a!=b" -func Neq(lhs, rhs Expression) BoolExpression { +func NotEq(lhs, rhs Expression) BoolExpression { lit, ok := rhs.(*literalExpression) if ok && sqltypes.Value(lit.value).IsNull() { return newBinaryBoolExpression(lhs, rhs, []byte(" IS NOT ")) @@ -206,7 +206,7 @@ func Neq(lhs, rhs Expression) BoolExpression { // Returns a representation of "a!=b", where b is a literal func NeqL(lhs Expression, val interface{}) BoolExpression { - return Neq(lhs, Literal(val)) + return NotEq(lhs, Literal(val)) } // Returns a representation of "ab" @@ -240,13 +240,13 @@ func GtL(lhs Expression, val interface{}) BoolExpression { } // Returns a representation of "a>=b" -func Gte(lhs, rhs Expression) BoolExpression { +func GtEq(lhs, rhs Expression) BoolExpression { return newBinaryBoolExpression(lhs, rhs, []byte(">=")) } // Returns a representation of "a>=b", where b is a literal func GteL(lhs Expression, val interface{}) BoolExpression { - return Gte(lhs, Literal(val)) + return GtEq(lhs, Literal(val)) } // Returns a representation of "not expr" diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index e0b60ff..3a2fec3 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -19,21 +19,6 @@ type Column interface { // Internal function for tracking tableName that a column belongs to // for the purpose of serialization setTableName(table string) error - - Asc() OrderByClause - Desc() OrderByClause -} - -type columnInterfaceImpl struct { - parent Column -} - -func (c *columnInterfaceImpl) Asc() OrderByClause { - return &orderByClause{expression: c.parent, ascent: true} -} - -func (c *columnInterfaceImpl) Desc() OrderByClause { - return &orderByClause{expression: c.parent, ascent: false} } type NullableColumn bool @@ -66,7 +51,6 @@ const ( // The base type for real materialized columns. type baseColumn struct { expressionInterfaceImpl - columnInterfaceImpl name string nullable NullableColumn @@ -81,7 +65,6 @@ func newBaseColumn(name string, nullable NullableColumn, tableName string, paren } bc.expressionInterfaceImpl.parent = parent - bc.columnInterfaceImpl.parent = parent return bc } diff --git a/sqlbuilder/column_types.go b/sqlbuilder/column_types.go index 186f5df..724f6a2 100644 --- a/sqlbuilder/column_types.go +++ b/sqlbuilder/column_types.go @@ -50,10 +50,6 @@ type IntegerColumn struct { // Representation of any integer column // This function will panic if name is not valid func NewIntegerColumn(name string, nullable NullableColumn) *IntegerColumn { - if !validIdentifierName(name) { - panic("Invalid column name") - } - integerColumn := &IntegerColumn{} integerColumn.numericInterfaceImpl.parent = integerColumn @@ -63,3 +59,27 @@ func NewIntegerColumn(name string, nullable NullableColumn) *IntegerColumn { return integerColumn } + +//------------------------------------------------------// +type StringColumn struct { + stringInterfaceImpl + + baseColumn +} + +// Representation of any integer column +// This function will panic if name is not valid +func NewStringColumn(name string, nullable NullableColumn) *StringColumn { + if !validIdentifierName(name) { + panic("Invalid column name") + } + + stringColumn := &StringColumn{} + + stringColumn.stringInterfaceImpl.parent = stringColumn + stringColumn.stringInterfaceImpl.parent = stringColumn + + stringColumn.baseColumn = newBaseColumn(name, nullable, "", stringColumn) + + return stringColumn +} diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 20a6c92..45b9a2a 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -9,17 +9,20 @@ import ( // An expression type Expression interface { Clause + Projection - As(alias string) Clause + As(alias string) Projection IsDistinct(expression Expression) BoolExpression IsNull() BoolExpression + Asc() OrderByClause + Desc() OrderByClause } type expressionInterfaceImpl struct { parent Expression } -func (e *expressionInterfaceImpl) As(alias string) Clause { +func (e *expressionInterfaceImpl) As(alias string) Projection { return NewAlias(e.parent, alias) } @@ -31,6 +34,18 @@ func (e *expressionInterfaceImpl) IsNull() BoolExpression { return nil } +func (e *expressionInterfaceImpl) Asc() OrderByClause { + return &orderByClause{expression: e.parent, ascent: true} +} + +func (e *expressionInterfaceImpl) Desc() OrderByClause { + return &orderByClause{expression: e.parent, ascent: false} +} + +func (e *expressionInterfaceImpl) SerializeForProjection(out *bytes.Buffer) error { + return e.parent.SerializeSql(out, FOR_PROJECTION) +} + // Representation of binary operations (e.g. comparisons, arithmetic) type binaryExpression struct { lhs, rhs Expression @@ -150,32 +165,32 @@ func (c literalExpression) SerializeSql(out *bytes.Buffer, options ...serializeO } //------------------------------------------------------// -// Dummy type for select * -type ColumnList []Column - -func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { - for i, column := range cl { - err := column.SerializeSql(out) - - if err != nil { - return err - } - - if i != len(cl)-1 { - out.WriteString(", ") - } - } - return nil -} - -func (e ColumnList) As(alias string) Clause { - panic("Invalid usage") -} - -func (e ColumnList) IsDistinct(expression Expression) BoolExpression { - panic("Invalid usage") -} - -func (e ColumnList) IsNull(expression Expression) BoolExpression { - panic("Invalid usage") -} +//// Dummy type for select * +//type ColumnList []Column +// +//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +// for i, column := range cl { +// err := column.SerializeSql(out) +// +// if err != nil { +// return err +// } +// +// if i != len(cl)-1 { +// out.WriteString(", ") +// } +// } +// return nil +//} +// +//func (e ColumnList) As(alias string) Clause { +// panic("Invalid usage") +//} +// +//func (e ColumnList) IsDistinct(expression Expression) BoolExpression { +// panic("Invalid usage") +//} +// +//func (e ColumnList) IsNull(expression Expression) BoolExpression { +// panic("Invalid usage") +//} diff --git a/sqlbuilder/func.go b/sqlbuilder/func.go index 08fbfb2..ae9babf 100644 --- a/sqlbuilder/func.go +++ b/sqlbuilder/func.go @@ -2,22 +2,31 @@ package sqlbuilder import "bytes" -type FuncExpression struct { +type FuncExpression interface { + Expression +} + +type numericFunc struct { + expressionInterfaceImpl + numericInterfaceImpl + name string expression Expression - - alias string } -func (f *FuncExpression) As(alias string) Clause { - newFuncExpression := *f +func NewNumericFunc(name string, expression Expression) NumericExpression { + numericFunc := &numericFunc{ + name: name, + expression: expression, + } - newFuncExpression.alias = alias + numericFunc.expressionInterfaceImpl.parent = numericFunc + numericFunc.numericInterfaceImpl.parent = numericFunc - return &newFuncExpression + return numericFunc } -func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { out.WriteString(f.name) out.WriteString("(") err := f.expression.SerializeSql(out) @@ -26,12 +35,6 @@ func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOpt } out.WriteString(")") - if f.alias != "" { - out.WriteString(` AS "`) - out.WriteString(f.alias) - out.WriteString(`"`) - } - return nil } @@ -39,16 +42,10 @@ func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOpt // return f.SerializeSql(out) //} -func MAX(expression Expression) *FuncExpression { - return &FuncExpression{ - name: "MAX", - expression: expression, - } +func MAX(expression NumericExpression) NumericExpression { + return NewNumericFunc("MAX", expression) } -func SUM(expression Expression) *FuncExpression { - return &FuncExpression{ - name: "SUM", - expression: expression, - } +func SUM(expression NumericExpression) NumericExpression { + return NewNumericFunc("SUM", expression) } diff --git a/sqlbuilder/numeric_expression.go b/sqlbuilder/numeric_expression.go index 688be82..5cf5eca 100644 --- a/sqlbuilder/numeric_expression.go +++ b/sqlbuilder/numeric_expression.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "bytes" "github.com/dropbox/godropbox/database/sqltypes" "github.com/pkg/errors" ) @@ -9,9 +10,13 @@ type NumericExpression interface { Expression Eq(expression NumericExpression) BoolExpression + EqL(literal interface{}) BoolExpression NotEq(expression NumericExpression) BoolExpression + NotEqL(literal interface{}) BoolExpression GtEq(rhs NumericExpression) BoolExpression + GtEqL(literal interface{}) BoolExpression LtEq(rhs NumericExpression) BoolExpression + LtEqL(literal interface{}) BoolExpression Add(expression NumericExpression) NumericExpression Sub(expression NumericExpression) NumericExpression @@ -27,16 +32,32 @@ func (n *numericInterfaceImpl) Eq(expression NumericExpression) BoolExpression { return Eq(n.parent, expression) } +func (n *numericInterfaceImpl) EqL(literal interface{}) BoolExpression { + return Eq(n.parent, Literal(literal)) +} + func (n *numericInterfaceImpl) NotEq(expression NumericExpression) BoolExpression { - return Neq(n.parent, expression) + return NotEq(n.parent, expression) +} + +func (n *numericInterfaceImpl) NotEqL(literal interface{}) BoolExpression { + return NotEq(n.parent, Literal(literal)) } func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression { - return Gte(n.parent, expression) + return GtEq(n.parent, expression) +} + +func (n *numericInterfaceImpl) GtEqL(literal interface{}) BoolExpression { + return GtEq(n.parent, Literal(literal)) } func (n *numericInterfaceImpl) LtEq(expression NumericExpression) BoolExpression { - return Lte(n.parent, expression) + return LtEq(n.parent, expression) +} + +func (n *numericInterfaceImpl) LtEqL(literal interface{}) BoolExpression { + return LtEq(n.parent, Literal(literal)) } func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression { @@ -92,3 +113,30 @@ func newBinaryNumericExpression(lhs, rhs Expression, operator []byte) NumericExp return &numericExpression } + +//---------------------------------------------------// +type numericExpressionWrapper struct { + expressionInterfaceImpl + numericInterfaceImpl + + expression Expression +} + +func newNumericExpressionWrap(expression Expression) NumericExpression { + numericExpressionWrap := numericExpressionWrapper{} + + numericExpressionWrap.expression = expression + + numericExpressionWrap.expressionInterfaceImpl.parent = &numericExpressionWrap + numericExpressionWrap.numericInterfaceImpl.parent = &numericExpressionWrap + + return &numericExpressionWrap +} + +func (c *numericExpressionWrapper) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { + out.WriteString("(") + err = c.expression.SerializeSql(out, options...) + out.WriteString(")") + + return nil +} diff --git a/sqlbuilder/projection.go b/sqlbuilder/projection.go new file mode 100644 index 0000000..41e80db --- /dev/null +++ b/sqlbuilder/projection.go @@ -0,0 +1,26 @@ +package sqlbuilder + +import "bytes" + +type Projection interface { + SerializeForProjection(out *bytes.Buffer) error +} + +//------------------------------------------------------// +// Dummy type for select * AllColumns +type ColumnList []Column + +func (cl ColumnList) SerializeForProjection(out *bytes.Buffer) error { + for i, column := range cl { + err := column.SerializeSql(out, FOR_PROJECTION) + + if err != nil { + return err + } + + if i != len(cl)-1 { + out.WriteString(", ") + } + } + return nil +} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 0525e7c..2c53fd2 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -36,7 +36,7 @@ type selectStatementImpl struct { expressionInterfaceImpl table ReadableTable - projections []Expression + projections []Projection where BoolExpression group *listClause having BoolExpression @@ -50,7 +50,7 @@ type selectStatementImpl struct { func newSelectStatement( table ReadableTable, - projections []Expression) SelectStatement { + projections []Projection) SelectStatement { return &selectStatementImpl{ table: table, @@ -210,7 +210,7 @@ func (q *selectStatementImpl) String() (sql string, err error) { "nil column selected. Generated sql: %s", buf.String()) } - if err = col.SerializeSql(buf, FOR_PROJECTION); err != nil { + if err = col.SerializeForProjection(buf); err != nil { return } } @@ -267,3 +267,7 @@ func (q *selectStatementImpl) String() (sql string, err error) { return buf.String(), nil } + +func NumExp(statement SelectStatement) NumericExpression { + return newNumericExpressionWrap(statement) +} diff --git a/sqlbuilder/select_statement_table.go b/sqlbuilder/select_statement_table.go index 7edff79..fb75e09 100644 --- a/sqlbuilder/select_statement_table.go +++ b/sqlbuilder/select_statement_table.go @@ -12,18 +12,24 @@ func (s *SelectStatementTable) Columns() []Column { return s.columns } -func (s *SelectStatementTable) Column(name string) Column { - return &baseColumn{ - name: name, - tableName: s.alias, - } +func (s *SelectStatementTable) RefIntColumnName(name string) Column { + intColumn := NewIntegerColumn(name, NotNullable) + intColumn.setTableName(s.alias) + + return intColumn } -func (s *SelectStatementTable) ColumnFrom(column Column) Column { - return &baseColumn{ - name: column.TableName() + "." + column.Name(), - tableName: s.alias, - } +func (s *SelectStatementTable) RefIntColumn(column Column) *IntegerColumn { + intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable) + intColumn.setTableName(s.alias) + + return intColumn +} + +func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn { + strColumn := NewStringColumn(column.Name(), NotNullable) + strColumn.setTableName(column.TableName()) + return strColumn } func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error { @@ -43,17 +49,17 @@ func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error { } // Generates a select query on the current tableName. -func (s *SelectStatementTable) Select(projections ...Expression) SelectStatement { +func (s *SelectStatementTable) SELECT(projections ...Projection) SelectStatement { return newSelectStatement(s, projections) } // Creates a inner join tableName expression using onCondition. -func (s *SelectStatementTable) InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (s *SelectStatementTable) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return InnerJoinOn(s, table, onCondition) } //func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable { -// return InnerJoinOn(s, table, col1.Eq(col2)) +// return INNER_JOIN(s, table, col1.Eq(col2)) //} // Creates a left join tableName expression using onCondition. @@ -66,7 +72,7 @@ func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition Bool return RightJoinOn(s, table, onCondition) } -func (s *SelectStatementTable) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (s *SelectStatementTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return FullJoin(s, table, onCondition) } diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 9d6c989..653c4a5 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -178,7 +178,7 @@ func (us *unionStatementImpl) String() (sql string, err error) { } // Union statements in MySQL require that the same number of columns in each subquery - var projections []Expression + var projections []Projection for _, statement := range us.selects { // do a type assertion to get at the underlying struct diff --git a/sqlbuilder/string_expression.go b/sqlbuilder/string_expression.go new file mode 100644 index 0000000..9cbc4ab --- /dev/null +++ b/sqlbuilder/string_expression.go @@ -0,0 +1,25 @@ +package sqlbuilder + +type StringExpression interface { + Expression + + Eq(expression StringExpression) BoolExpression + EqL(value string) BoolExpression + NotEq(expression StringExpression) BoolExpression +} + +type stringInterfaceImpl struct { + parent StringExpression +} + +func (b *stringInterfaceImpl) Eq(expression StringExpression) BoolExpression { + return newBinaryBoolExpression(b.parent, expression, []byte(" = ")) +} + +func (b *stringInterfaceImpl) EqL(value string) BoolExpression { + return newBinaryBoolExpression(b.parent, Literal(value), []byte(" = ")) +} + +func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression { + return newBinaryBoolExpression(b.parent, expression, []byte(" != ")) +} diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 99c7810..98c4867 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -14,17 +14,17 @@ type ReadableTable interface { // Returns the list of columns that are in the current tableName expression. Columns() []Column - Column(name string) Column + //Column(name string) Column // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. SerializeSql(out *bytes.Buffer) error // Generates a select query on the current tableName. - Select(projections ...Expression) SelectStatement + SELECT(projections ...Projection) SelectStatement // Creates a inner join tableName expression using onCondition. - InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable //InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable @@ -34,7 +34,7 @@ type ReadableTable interface { // Creates a right join tableName expression using onCondition. RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable - FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable + FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable CrossJoin(table ReadableTable) ReadableTable } @@ -181,12 +181,12 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error { } // Generates a select query on the current tableName. -func (t *Table) Select(projections ...Expression) SelectStatement { +func (t *Table) SELECT(projections ...Projection) SelectStatement { return newSelectStatement(t, projections) } // Creates a inner join tableName expression using onCondition. -func (t *Table) InnerJoinOn( +func (t *Table) INNER_JOIN( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -198,7 +198,7 @@ func (t *Table) InnerJoinOn( // col1 Column, // col2 Column) ReadableTable { // -// return InnerJoinOn(t, table, col1.Eq(col2)) +// return INNER_JOIN(t, table, col1.Eq(col2)) //} // Creates a left join tableName expression using onCondition. @@ -217,7 +217,7 @@ func (t *Table) RightJoinOn( return RightJoinOn(t, table, onCondition) } -func (t *Table) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (t *Table) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return FullJoin(t, table, onCondition) } @@ -363,11 +363,11 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { return nil } -func (t *joinTable) Select(projections ...Expression) SelectStatement { +func (t *joinTable) SELECT(projections ...Projection) SelectStatement { return newSelectStatement(t, projections) } -func (t *joinTable) InnerJoinOn( +func (t *joinTable) INNER_JOIN( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -381,7 +381,7 @@ func (t *joinTable) LeftJoinOn( return LeftJoinOn(t, table, onCondition) } -func (t *joinTable) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (t *joinTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return FullJoin(t, table, onCondition) } diff --git a/tests/generator_test.go b/tests/generator_test.go index bd47281..c3e3cae 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -58,7 +58,15 @@ func TestGenerateModel(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) { actor := model.Actor{} - err := Actor.Select(Actor.AllColumns).Execute(db, &actor) + query := Actor.SELECT(Actor.AllColumns) + + queryStr, err := query.String() + + fmt.Println(queryStr) + + assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor`) + + err = query.Execute(db, &actor) assert.NilError(t, err) @@ -75,7 +83,7 @@ func TestSelect_ScanToStruct(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) { customers := []model.Customer{} - query := Customer.Select(Customer.AllColumns).OrderBy(Customer.CustomerID.Asc()) + query := Customer.SELECT(Customer.AllColumns).OrderBy(Customer.CustomerID.Asc()) queryStr, err := query.String() assert.NilError(t, err) @@ -92,30 +100,30 @@ func TestSelect_ScanToSlice(t *testing.T) { assert.DeepEqual(t, lastCustomer, customers[598]) } -func TestJoinQueryStruct(t *testing.T) { - - query := FilmActor. - InnerJoinUsing(Actor, FilmActor.ActorID, Actor.ActorID). - InnerJoinUsing(Film, FilmActor.FilmID, Film.FilmID). - InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). - Select(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). - Where(FilmActor.ActorID.GteLiteral(1).And(FilmActor.ActorID.LteLiteral(2))) - - queryStr, err := query.String() - assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) - - //fmt.Println(queryStr) - - filmActor := []model.FilmActor{} - - err = query.Execute(db, &filmActor) - - assert.NilError(t, err) - - //fmt.Println("ACTORS: --------------------") - //spew.Dump(filmActor) -} +//func TestJoinQueryStruct(t *testing.T) { +// +// query := FilmActor. +// INNER_JOIN(Actor, FilmActor.ActorID.Eq(Actor.ActorID)). +// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). +// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). +// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). +// Where(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) +// +// queryStr, err := query.String() +// assert.NilError(t, err) +// assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) +// +// //fmt.Println(queryStr) +// +// filmActor := []model.FilmActor{} +// +// err = query.Execute(db, &filmActor) +// +// assert.NilError(t, err) +// +// //fmt.Println("ACTORS: --------------------") +// //spew.Dump(filmActor) +//} func TestJoinQuerySlice(t *testing.T) { type FilmsPerLanguage struct { @@ -126,8 +134,9 @@ func TestJoinQuerySlice(t *testing.T) { filmsPerLanguage := []FilmsPerLanguage{} limit := 15 - query := Film.InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). - Select(Language.AllColumns, Film.AllColumns). + query := Film. + INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). + SELECT(Language.AllColumns, Film.AllColumns). Limit(15) queryStr, err := query.String() @@ -167,8 +176,8 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { limit := int64(3) - query := Film.InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). - Select(Language.AllColumns, Film.AllColumns). + query := Film.INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). + SELECT(Language.AllColumns, Film.AllColumns). Limit(limit) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} @@ -182,7 +191,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { } func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { - query := Customer.Select(Customer.FirstName, Customer.LastName, Customer.Email) + query := Customer.SELECT(Customer.FirstName, Customer.LastName, Customer.Email) customers := []model.Customer{} @@ -198,7 +207,7 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { func TestSelectOrderByAscDesc(t *testing.T) { customersAsc := []model.Customer{} - err := Customer.Select(Customer.CustomerID, Customer.FirstName, Customer.LastName). + err := Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). OrderBy(Customer.FirstName.Asc()). Execute(db, &customersAsc) @@ -208,7 +217,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { lastCustomerAsc := customersAsc[len(customersAsc)-1] customersDesc := []model.Customer{} - err = Customer.Select(Customer.CustomerID, Customer.FirstName, Customer.LastName). + err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). OrderBy(Customer.FirstName.Desc()). Execute(db, &customersDesc) @@ -221,7 +230,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { assert.DeepEqual(t, lastCustomerAsc, firstCustomerDesc) customersAscDesc := []model.Customer{} - err = Customer.Select(Customer.CustomerID, Customer.FirstName, Customer.LastName). + err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). OrderBy(Customer.FirstName.Asc(), Customer.LastName.Desc()). Execute(db, &customersAscDesc) @@ -245,8 +254,8 @@ func TestSelectOrderByAscDesc(t *testing.T) { func TestSelectFullJoin(t *testing.T) { query := Customer. - FullJoin(Address, Customer.AddressID, Address.AddressID). - Select(Customer.AllColumns, Address.AllColumns). + FULL_JOIN(Address, Customer.AddressID.Eq(Address.AddressID)). + SELECT(Customer.AllColumns, Address.AllColumns). OrderBy(Customer.CustomerID.Asc()) queryStr, err := query.String() @@ -278,7 +287,7 @@ func TestSelectFullJoin(t *testing.T) { func TestSelectFullCrossJoin(t *testing.T) { query := Customer. CrossJoin(Address). - Select(Customer.AllColumns, Address.AllColumns). + SELECT(Customer.AllColumns, Address.AllColumns). OrderBy(Customer.CustomerID.Asc()). Limit(1000) @@ -304,9 +313,9 @@ func TestSelectSelfJoin(t *testing.T) { f2 := Film.As("f2") query := f1. - InnerJoinOn(f2, f1.FilmID.Neq(f2.FilmID).And(f1.Length.Eq(f2.Length))). - Select(f1.AllColumns, f2.AllColumns). - OrderBy(f1.FilmID) + INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). + SELECT(f1.AllColumns, f2.AllColumns). + OrderBy(f1.FilmID.Asc()) queryStr, err := query.String() @@ -342,8 +351,8 @@ func TestSelectAliasColumn(t *testing.T) { } query := f1. - InnerJoinOn(f2, f1.FilmID.Neq(f2.FilmID).And(f1.Length.Eq(f2.Length))). - Select(f1.Title.As("thesame_length_films.title1"), + INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). + SELECT(f1.Title.As("thesame_length_films.title1"), f2.Title.As("thesame_length_films.title2"), f1.Length.As("thesame_length_films.length")). OrderBy(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()). @@ -388,9 +397,9 @@ func TestSelectSelfReferenceType(t *testing.T) { manager := Staff.As("manager") query := Staff. - InnerJoinUsing(Address, Staff.AddressID, Address.AddressID). - InnerJoinUsing(manager, Staff.StaffID, manager.StaffID). - Select(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName) + INNER_JOIN(Address, Staff.AddressID.Eq(Address.AddressID)). + INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). + SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName) queryStr, err := query.String() assert.NilError(t, err) @@ -407,11 +416,11 @@ func TestSelectSelfReferenceType(t *testing.T) { func TestSubQuery(t *testing.T) { - //selectStmtTable := Actor.Select(Actor.FirstName, Actor.LastName).AsTable("table_expression") + //selectStmtTable := Actor.SELECT(Actor.FirstName, Actor.LastName).AsTable("table_expression") // - //query := selectStmtTable.Select( - // selectStmtTable.ColumnFrom(Actor.FirstName).As("nesto"), - // selectStmtTable.Column("actor.last_name").As("nesto2"), + //query := selectStmtTable.SELECT( + // selectStmtTable.RefStringColumn(Actor.FirstName).As("nesto"), + // selectStmtTable.RefIntColumnName("actor.last_name").As("nesto2"), // ) // //queryStr, err := query.String() @@ -419,25 +428,25 @@ func TestSubQuery(t *testing.T) { //assert.NilError(t, err) // //fmt.Println(queryStr) - - //avrgCustomer := Customer.Select(Customer.LastName).Limit(1).AsExpression() + // + //avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).Limit(1)) // //Customer. - // InnerJoinUsing(selectStmtTable, Customer.LastName, selectStmtTable.Column("first_name")). - // Select(Customer.AllColumns, selectStmtTable.Column("first_name")). + // INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))). + // SELECT(Customer.AllColumns, selectStmtTable.RefIntColumnName("first_name")). // Where(Actor.LastName.Neq(avrgCustomer)) - rFilmsOnly := Film.Select(Film.FilmID, Film.Title, Film.Rating). - Where(Film.Rating.Eq(sqlbuilder.Literal("R"))). + rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). + Where(Film.Rating.EqL("R")). AsTable("films") - query := Actor.InnerJoinUsing(FilmActor, Actor.ActorID, FilmActor.FilmID). - InnerJoinUsing(rFilmsOnly, FilmActor.FilmID, rFilmsOnly.ColumnFrom(Film.FilmID)). - Select( + query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)). + INNER_JOIN(rFilmsOnly, FilmActor.FilmID.Eq(rFilmsOnly.RefIntColumn(Film.FilmID))). + SELECT( Actor.AllColumns, FilmActor.AllColumns, - rFilmsOnly.ColumnFrom(Film.Title).As("film.title"), - rFilmsOnly.ColumnFrom(Film.Rating).As("film.rating"), + rFilmsOnly.RefStringColumn(Film.Title).As("film.title"), + rFilmsOnly.RefStringColumn(Film.Rating).As("film.rating"), ) queryStr, err := query.String() @@ -449,7 +458,7 @@ func TestSubQuery(t *testing.T) { } func TestSelectFunctions(t *testing.T) { - query := Film.Select(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate")) + query := Film.SELECT(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate")) str, err := query.String() @@ -462,11 +471,11 @@ func TestSelectFunctions(t *testing.T) { func TestSelectQueryScalar(t *testing.T) { - maxFilmRentalRate := Film.Select(sqlbuilder.MAX(Film.RentalRate)) + maxFilmRentalRate := sqlbuilder.NumExp(Film.SELECT(sqlbuilder.MAX(Film.RentalRate))) - query := Film.Select(Film.AllColumns). + query := Film.SELECT(Film.AllColumns). Where(Film.RentalRate.Eq(maxFilmRentalRate)). - OrderBy(Film.FilmID) + OrderBy(Film.FilmID.Asc()) queryStr, err := query.String() @@ -502,12 +511,12 @@ func TestSelectQueryScalar(t *testing.T) { func TestSelectGroupByHaving(t *testing.T) { customersPaymentQuery := Payment. - Select( + SELECT( Payment.CustomerID.As("customer_payment_sum.customer_id"), sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"), ). GroupBy(Payment.CustomerID). - OrderBy(sqlbuilder.SUM(Payment.Amount)). + OrderBy(sqlbuilder.SUM(Payment.Amount).Asc()). HAVING(sqlbuilder.Gt(sqlbuilder.SUM(Payment.Amount), sqlbuilder.Literal(100))) queryStr, err := customersPaymentQuery.String() @@ -515,8 +524,7 @@ func TestSelectGroupByHaving(t *testing.T) { assert.NilError(t, err) fmt.Println(queryStr) - assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount)`) - + assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount) ASC`) type CustomerPaymentSum struct { CustomerID int16 AmountSum float64 @@ -543,19 +551,19 @@ func TestSelectGroupBy2(t *testing.T) { customersWithAmounts := []CustomerWithAmounts{} customersPaymentSubQuery := Payment. - Select( + SELECT( Payment.CustomerID, sqlbuilder.SUM(Payment.Amount).As("amount_sum"), ). GroupBy(Payment.CustomerID) customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum") - amountSumColumn := customersPaymentTable.Column("amount_sum") + amountSumColumn := customersPaymentTable.RefIntColumnName("amount_sum") query := Customer. - InnerJoinUsing(customersPaymentTable, Customer.CustomerID, customersPaymentTable.ColumnFrom(Payment.CustomerID)). - Select(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")). - OrderBy(amountSumColumn) + INNER_JOIN(customersPaymentTable, Customer.CustomerID.Eq(customersPaymentTable.RefIntColumn(Payment.CustomerID))). + SELECT(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")). + OrderBy(amountSumColumn.Asc()) queryStr, err := query.String() assert.NilError(t, err)