From 4d7fbf8f49a33b8a0954a6db5f2e856a60790196 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Wed, 5 Jun 2019 17:15:20 +0200 Subject: [PATCH] Table interface cleanup. --- generator/templates.go | 10 +- sqlbuilder/bool_expresion.go | 16 +- sqlbuilder/bool_expression_test.go | 36 +-- sqlbuilder/clause.go | 2 +- sqlbuilder/column.go | 44 ++-- sqlbuilder/column_types.go | 36 +-- sqlbuilder/column_types_test.go | 6 +- sqlbuilder/delete_statement.go | 8 +- sqlbuilder/expression.go | 10 +- sqlbuilder/expression_table.go | 65 ++---- sqlbuilder/expression_test.go | 44 ++-- sqlbuilder/float_expression_test.go | 48 ++-- sqlbuilder/func_expression_test.go | 116 ++++----- sqlbuilder/insert_statement.go | 10 +- sqlbuilder/integer_expression_test.go | 56 ++--- sqlbuilder/literal_expression_test.go | 2 +- sqlbuilder/lock_statement.go | 4 +- sqlbuilder/operators.go | 41 ---- sqlbuilder/operators_test.go | 8 +- sqlbuilder/order_by_clause.go | 4 +- sqlbuilder/projection.go | 2 +- sqlbuilder/select_statement.go | 15 +- sqlbuilder/set_statement.go | 13 +- sqlbuilder/statement_test.go | 2 +- sqlbuilder/string_expression_test.go | 44 ++-- sqlbuilder/table.go | 323 ++++++++++---------------- sqlbuilder/table_test.go | 257 +++++++------------- sqlbuilder/test_utils.go | 46 ++-- sqlbuilder/time_expression_test.go | 24 +- sqlbuilder/update_statement.go | 8 +- sqlbuilder/utils.go | 7 +- tests/select_test.go | 30 +-- 32 files changed, 543 insertions(+), 794 deletions(-) diff --git a/generator/templates.go b/generator/templates.go index 1deda60..2b55e92 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -20,10 +20,6 @@ var sqlBuilderTableTemplate = ` {{- end}} {{- end}} -{{define "nullable" -}} - {{- if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}} -{{- end}} - package table import ( @@ -46,12 +42,12 @@ var {{camelize .Name}} = new{{.GoStructName}}() func new{{.GoStructName}}() *{{.GoStructName}} { var ( {{- range .Columns}} - {{camelize .Name}}Column = sqlbuilder.New{{.SqlBuilderColumnType}}("{{.Name}}", {{template "nullable" .}}) + {{camelize .Name}}Column = sqlbuilder.New{{.SqlBuilderColumnType}}("{{.Name}}", {{.IsNullable}}) {{- end}} ) return &{{.GoStructName}}{ - Table: *sqlbuilder.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), + Table: sqlbuilder.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), //Columns {{- range .Columns}} @@ -65,7 +61,7 @@ func new{{.GoStructName}}() *{{.GoStructName}} { func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { aliasTable := new{{.GoStructName}}() - aliasTable.Table.SetAlias(alias) + aliasTable.Table.AS(alias) return aliasTable } diff --git a/sqlbuilder/bool_expresion.go b/sqlbuilder/bool_expresion.go index fddb37e..5f95eea 100644 --- a/sqlbuilder/bool_expresion.go +++ b/sqlbuilder/bool_expresion.go @@ -40,35 +40,35 @@ func (b *boolInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BoolExpression) BoolExpress } func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression { - return And(b.parent, expression) + return newBinaryBoolExpression(b.parent, expression, "AND") } func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression { - return Or(b.parent, expression) + return newBinaryBoolExpression(b.parent, expression, "OR") } func (b *boolInterfaceImpl) IS_TRUE() BoolExpression { - return IS_TRUE(b.parent) + return newPostifxBoolExpression(b.parent, "IS TRUE") } func (b *boolInterfaceImpl) IS_NOT_TRUE() BoolExpression { - return IS_NOT_TRUE(b.parent) + return newPostifxBoolExpression(b.parent, "IS NOT TRUE") } func (b *boolInterfaceImpl) IS_FALSE() BoolExpression { - return IS_FALSE(b.parent) + return newPostifxBoolExpression(b.parent, "IS FALSE") } func (b *boolInterfaceImpl) IS_NOT_FALSE() BoolExpression { - return IS_NOT_FALSE(b.parent) + return newPostifxBoolExpression(b.parent, "IS NOT FALSE") } func (b *boolInterfaceImpl) IS_UNKNOWN() BoolExpression { - return IS_UNKNOWN(b.parent) + return newPostifxBoolExpression(b.parent, "IS UNKNOWN") } func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { - return IS_NOT_UNKNOWN(b.parent) + return newPostifxBoolExpression(b.parent, "IS NOT UNKNOWN") } //---------------------------------------------------// diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index de5e2bc..a8e0bf5 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -5,62 +5,62 @@ import ( ) func TestBoolExpressionEQ(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.colBool = table2.colBool)") - assertExpressionSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.colBool = $1)", true) + assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.colBool = table2.colBool)") + assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.colBool = $1)", true) } func TestBoolExpressionNOT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.NOT_EQ(table2ColBool), "(table1.colBool != table2.colBool)") - assertExpressionSerialize(t, table1ColBool.NOT_EQ(Bool(true)), "(table1.colBool != $1)", true) + assertClauseSerialize(t, table1ColBool.NOT_EQ(table2ColBool), "(table1.colBool != table2.colBool)") + assertClauseSerialize(t, table1ColBool.NOT_EQ(Bool(true)), "(table1.colBool != $1)", true) } func TestBoolExpressionIS_TRUE(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.IS_TRUE(), "table1.colBool IS TRUE") - assertExpressionSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(), + assertClauseSerialize(t, table1ColBool.IS_TRUE(), "table1.colBool IS TRUE") + assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(), `($1 = table1.colInt) IS TRUE`, int64(2)) - assertExpressionSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE().AND(Int(4).EQ(table2ColInt)), + assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE().AND(Int(4).EQ(table2ColInt)), `(($1 = table1.colInt) IS TRUE AND ($2 = table2.colInt))`, int64(2), int64(4)) } func TestBoolExpressionIS_NOT_TRUE(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.IS_NOT_TRUE(), "table1.colBool IS NOT TRUE") + assertClauseSerialize(t, table1ColBool.IS_NOT_TRUE(), "table1.colBool IS NOT TRUE") } func TestBoolExpressionIS_FALSE(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.IS_FALSE(), "table1.colBool IS FALSE") + assertClauseSerialize(t, table1ColBool.IS_FALSE(), "table1.colBool IS FALSE") } func TestBoolExpressionIS_NOT_FALSE(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.IS_NOT_FALSE(), "table1.colBool IS NOT FALSE") + assertClauseSerialize(t, table1ColBool.IS_NOT_FALSE(), "table1.colBool IS NOT FALSE") } func TestBoolExpressionIS_UNKNOWN(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.IS_UNKNOWN(), "table1.colBool IS UNKNOWN") + assertClauseSerialize(t, table1ColBool.IS_UNKNOWN(), "table1.colBool IS UNKNOWN") } func TestBoolExpressionIS_NOT_UNKNOWN(t *testing.T) { - assertExpressionSerialize(t, table1ColBool.IS_NOT_UNKNOWN(), "table1.colBool IS NOT UNKNOWN") + assertClauseSerialize(t, table1ColBool.IS_NOT_UNKNOWN(), "table1.colBool IS NOT UNKNOWN") } func TestBinaryBoolExpression(t *testing.T) { boolExpression := Int(2).EQ(Int(3)) - assertExpressionSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3)) + assertClauseSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3)) assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"), `$1 = $2 AS "alias_eq_expression"`, int64(2), int64(3)) - assertExpressionSerialize(t, boolExpression.AND(Int(4).EQ(Int(5))), + assertClauseSerialize(t, boolExpression.AND(Int(4).EQ(Int(5))), "(($1 = $2) AND ($3 = $4))", int64(2), int64(3), int64(4), int64(5)) - assertExpressionSerialize(t, boolExpression.OR(Int(4).EQ(Int(5))), + assertClauseSerialize(t, boolExpression.OR(Int(4).EQ(Int(5))), "(($1 = $2) OR ($3 = $4))", int64(2), int64(3), int64(4), int64(5)) } func TestBoolLiteral(t *testing.T) { - assertExpressionSerialize(t, Bool(true), "$1", true) - assertExpressionSerialize(t, Bool(false), "$1", false) + assertClauseSerialize(t, Bool(true), "$1", true) + assertClauseSerialize(t, Bool(false), "$1", false) } func TestExists(t *testing.T) { - assertExpressionSerialize(t, EXISTS( + assertClauseSerialize(t, EXISTS( table2. SELECT(Int(1)). WHERE(table1Col1.EQ(table2Col3)), diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 1a75d02..87e5857 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -65,7 +65,7 @@ func (q *queryData) writeProjection(statement statementType, projections []proje return err } -func (q *queryData) writeFrom(statement statementType, table tableInterface) error { +func (q *queryData) writeFrom(statement statementType, table ReadableTable) error { q.nextLine() q.writeString("FROM") diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 4a0fe2b..642f8cd 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -6,54 +6,32 @@ import ( "strings" ) -type column interface { +type Column interface { Expression Name() string TableName() string - + IsNullable() bool DefaultAlias() projection // Internal function for tracking tableName that a column belongs to // for the purpose of serialization setTableName(table string) } -type NullableColumn bool - -const ( - Nullable NullableColumn = true - NotNullable NullableColumn = false -) - -type Collation string - -const ( - UTF8CaseInsensitive Collation = "utf8_unicode_ci" - UTF8CaseSensitive Collation = "utf8_unicode" - UTF8Binary Collation = "utf8_bin" -) - -// Representation of MySQL charsets -type Charset string - -const ( - UTF8 Charset = "utf8" -) - // The base type for real materialized columns. type baseColumn struct { expressionInterfaceImpl - name string - nullable NullableColumn - tableName string + name string + isNullable bool + tableName string } -func newBaseColumn(name string, nullable NullableColumn, tableName string, parent column) baseColumn { +func newBaseColumn(name string, isNullable bool, tableName string, parent Column) baseColumn { bc := baseColumn{ - name: name, - nullable: nullable, - tableName: tableName, + name: name, + isNullable: isNullable, + tableName: tableName, } bc.expressionInterfaceImpl.parent = parent @@ -73,6 +51,10 @@ func (c *baseColumn) setTableName(table string) { c.tableName = table } +func (c *baseColumn) IsNullable() bool { + return c.isNullable +} + func (c *baseColumn) DefaultAlias() projection { return c.AS(c.tableName + "." + c.name) } diff --git a/sqlbuilder/column_types.go b/sqlbuilder/column_types.go index 613bb4f..4a27da1 100644 --- a/sqlbuilder/column_types.go +++ b/sqlbuilder/column_types.go @@ -7,10 +7,10 @@ type BoolColumn struct { baseColumn } -func NewBoolColumn(name string, nullable NullableColumn) *BoolColumn { +func NewBoolColumn(name string, isNullable bool) *BoolColumn { boolColumn := &BoolColumn{} - boolColumn.baseColumn = newBaseColumn(name, nullable, "", boolColumn) + boolColumn.baseColumn = newBaseColumn(name, isNullable, "", boolColumn) boolColumn.boolInterfaceImpl.parent = boolColumn @@ -23,13 +23,13 @@ type FloatColumn struct { baseColumn } -func NewFloatColumn(name string, nullable NullableColumn) *FloatColumn { +func NewFloatColumn(name string, isNullable bool) *FloatColumn { floatColumn := &FloatColumn{} floatColumn.floatInterfaceImpl.parent = floatColumn - floatColumn.baseColumn = newBaseColumn(name, nullable, "", floatColumn) + floatColumn.baseColumn = newBaseColumn(name, isNullable, "", floatColumn) return floatColumn } @@ -43,12 +43,12 @@ type IntegerColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewIntegerColumn(name string, nullable NullableColumn) *IntegerColumn { +func NewIntegerColumn(name string, isNullable bool) *IntegerColumn { integerColumn := &IntegerColumn{} integerColumn.integerInterfaceImpl.parent = integerColumn - integerColumn.baseColumn = newBaseColumn(name, nullable, "", integerColumn) + integerColumn.baseColumn = newBaseColumn(name, isNullable, "", integerColumn) return integerColumn } @@ -62,13 +62,13 @@ type StringColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewStringColumn(name string, nullable NullableColumn) *StringColumn { +func NewStringColumn(name string, isNullable bool) *StringColumn { stringColumn := &StringColumn{} stringColumn.stringInterfaceImpl.parent = stringColumn - stringColumn.baseColumn = newBaseColumn(name, nullable, "", stringColumn) + stringColumn.baseColumn = newBaseColumn(name, isNullable, "", stringColumn) return stringColumn } @@ -82,12 +82,12 @@ type TimeColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewTimeColumn(name string, nullable NullableColumn) *TimeColumn { +func NewTimeColumn(name string, isNullable bool) *TimeColumn { timeColumn := &TimeColumn{} timeColumn.timeInterfaceImpl.parent = timeColumn - timeColumn.baseColumn = newBaseColumn(name, nullable, "", timeColumn) + timeColumn.baseColumn = newBaseColumn(name, isNullable, "", timeColumn) return timeColumn } @@ -101,12 +101,12 @@ type TimezColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewTimezColumn(name string, nullable NullableColumn) *TimezColumn { +func NewTimezColumn(name string, isNullable bool) *TimezColumn { timezColumn := &TimezColumn{} timezColumn.timezInterfaceImpl.parent = timezColumn - timezColumn.baseColumn = newBaseColumn(name, nullable, "", timezColumn) + timezColumn.baseColumn = newBaseColumn(name, isNullable, "", timezColumn) return timezColumn } @@ -120,12 +120,12 @@ type TimestampColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewTimestampColumn(name string, nullable NullableColumn) *TimestampColumn { +func NewTimestampColumn(name string, isNullable bool) *TimestampColumn { timestampColumn := &TimestampColumn{} timestampColumn.timestampInterfaceImpl.parent = timestampColumn - timestampColumn.baseColumn = newBaseColumn(name, nullable, "", timestampColumn) + timestampColumn.baseColumn = newBaseColumn(name, isNullable, "", timestampColumn) return timestampColumn } @@ -139,12 +139,12 @@ type TimestampzColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewTimestampzColumn(name string, nullable NullableColumn) *TimestampzColumn { +func NewTimestampzColumn(name string, isNullable bool) *TimestampzColumn { timestampzColumn := &TimestampzColumn{} timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn - timestampzColumn.baseColumn = newBaseColumn(name, nullable, "", timestampzColumn) + timestampzColumn.baseColumn = newBaseColumn(name, isNullable, "", timestampzColumn) return timestampzColumn } @@ -158,12 +158,12 @@ type DateColumn struct { // Representation of any integer column // This function will panic if name is not valid -func NewDateColumn(name string, nullable NullableColumn) *DateColumn { +func NewDateColumn(name string, isNullable bool) *DateColumn { dateColumn := &DateColumn{} dateColumn.dateInterfaceImpl.parent = dateColumn - dateColumn.baseColumn = newBaseColumn(name, nullable, "", dateColumn) + dateColumn.baseColumn = newBaseColumn(name, isNullable, "", dateColumn) return dateColumn } diff --git a/sqlbuilder/column_types_test.go b/sqlbuilder/column_types_test.go index d127295..cc9fa9a 100644 --- a/sqlbuilder/column_types_test.go +++ b/sqlbuilder/column_types_test.go @@ -6,7 +6,7 @@ import ( ) func TestNewBoolColumn(t *testing.T) { - boolColumn := NewBoolColumn("col", Nullable) + boolColumn := NewBoolColumn("col", false) out := queryData{} err := boolColumn.serialize(select_statement, &out) @@ -34,7 +34,7 @@ func TestNewBoolColumn(t *testing.T) { } func TestNewIntColumn(t *testing.T) { - integerColumn := NewIntegerColumn("col", Nullable) + integerColumn := NewIntegerColumn("col", false) out := queryData{} err := integerColumn.serialize(select_statement, &out) @@ -62,7 +62,7 @@ func TestNewIntColumn(t *testing.T) { } func TestNewNumericColumnColumn(t *testing.T) { - numericColumn := NewFloatColumn("col", Nullable) + numericColumn := NewFloatColumn("col", false) out := queryData{} err := numericColumn.serialize(select_statement, &out) diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index 60587ec..ca1e801 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -2,7 +2,7 @@ package sqlbuilder import ( "database/sql" - "github.com/dropbox/godropbox/errors" + "errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" ) @@ -12,14 +12,14 @@ type DeleteStatement interface { WHERE(expression BoolExpression) DeleteStatement } -func newDeleteStatement(table writableTable) DeleteStatement { +func newDeleteStatement(table WritableTable) DeleteStatement { return &deleteStatementImpl{ table: table, } } type deleteStatementImpl struct { - table writableTable + table WritableTable where BoolExpression } @@ -30,7 +30,7 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { func (d *deleteStatementImpl) serializeImpl(out *queryData) error { if d == nil { - return errors.New("Delete statement. ") + return errors.New("Delete expression. ") } out.nextLine() out.writeString("DELETE FROM") diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 8b6eb8e..3eebb54 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -1,7 +1,7 @@ package sqlbuilder import ( - "github.com/dropbox/godropbox/errors" + "errors" ) // An Expression @@ -139,10 +139,10 @@ func (c *binaryOpExpression) serialize(statement statementType, out *queryData, return errors.New("Binary Expression is nil.") } if c.lhs == nil { - return errors.Newf("nil lhs.") + return errors.New("nil lhs.") } if c.rhs == nil { - return errors.Newf("nil rhs.") + return errors.New("nil rhs.") } wrap := !contains(options, NO_WRAP) @@ -191,7 +191,7 @@ func (p *prefixOpExpression) serialize(statement statementType, out *queryData, out.writeString(p.operator + " ") if p.expression == nil { - return errors.Newf("nil prefix Expression.") + return errors.New("nil prefix Expression.") } if err := p.expression.serialize(statement, out); err != nil { return err @@ -221,7 +221,7 @@ func (p *postfixOpExpression) serialize(statement statementType, out *queryData, } if p.expression == nil { - return errors.Newf("nil prefix Expression.") + return errors.New("nil prefix Expression.") } if err := p.expression.serialize(statement, out); err != nil { return err diff --git a/sqlbuilder/expression_table.go b/sqlbuilder/expression_table.go index ac9352e..0e190e1 100644 --- a/sqlbuilder/expression_table.go +++ b/sqlbuilder/expression_table.go @@ -6,13 +6,22 @@ type expressionTable interface { ReadableTable RefIntColumnName(name string) *IntegerColumn - RefIntColumn(column column) *IntegerColumn - RefStringColumn(column column) *StringColumn + RefIntColumn(column Column) *IntegerColumn + RefStringColumn(column Column) *StringColumn } type expressionTableImpl struct { - statement Expression - alias string + readableTableInterfaceImpl + expression Expression + alias string +} + +func newExpressionTable(expression Expression, alias string) expressionTable { + expTable := &expressionTableImpl{expression: expression, alias: alias} + + expTable.readableTableInterfaceImpl.parent = expTable + + return expTable } // Returns the tableName's name in the database @@ -24,26 +33,22 @@ func (e *expressionTableImpl) TableName() string { return e.alias } -func (e *expressionTableImpl) Columns() []column { - return []column{} -} - func (e *expressionTableImpl) RefIntColumnName(name string) *IntegerColumn { - intColumn := NewIntegerColumn(name, NotNullable) + intColumn := NewIntegerColumn(name, false) intColumn.setTableName(e.alias) return intColumn } -func (e *expressionTableImpl) RefIntColumn(column column) *IntegerColumn { - intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable) +func (e *expressionTableImpl) RefIntColumn(column Column) *IntegerColumn { + intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), false) intColumn.setTableName(e.alias) return intColumn } -func (e *expressionTableImpl) RefStringColumn(column column) *StringColumn { - strColumn := NewStringColumn(column.TableName()+"."+column.Name(), NotNullable) +func (e *expressionTableImpl) RefStringColumn(column Column) *StringColumn { + strColumn := NewStringColumn(column.TableName()+"."+column.Name(), false) strColumn.setTableName(e.alias) return strColumn } @@ -53,7 +58,7 @@ func (e *expressionTableImpl) serialize(statement statementType, out *queryData, return errors.New("Expression table is nil. ") } //out.writeString("( ") - err := e.statement.serialize(statement, out) + err := e.expression.serialize(statement, out) if err != nil { return err @@ -64,35 +69,3 @@ func (e *expressionTableImpl) serialize(statement statementType, out *queryData, return nil } - -// Generates a select query on the current tableName. -func (e *expressionTableImpl) SELECT(projections ...projection) SelectStatement { - return newSelectStatement(e, projections) -} - -// Creates a inner join tableName Expression using onCondition. -func (e *expressionTableImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return InnerJoinOn(e, table, onCondition) -} - -//func (s *expressionTableImpl) InnerJoinUsing(table ReadableTable, col1 column, col2 column) ReadableTable { -// return INNER_JOIN(s, table, col1.EQ(col2)) -//} - -// Creates a left join tableName Expression using onCondition. -func (e *expressionTableImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return LeftJoinOn(e, table, onCondition) -} - -// Creates a right join tableName Expression using onCondition. -func (e *expressionTableImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return RightJoinOn(e, table, onCondition) -} - -func (e *expressionTableImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return FullJoin(e, table, onCondition) -} - -func (e *expressionTableImpl) CROSS_JOIN(table ReadableTable) ReadableTable { - return CrossJoin(e, table) -} diff --git a/sqlbuilder/expression_test.go b/sqlbuilder/expression_test.go index 59289e5..126b147 100644 --- a/sqlbuilder/expression_test.go +++ b/sqlbuilder/expression_test.go @@ -5,71 +5,71 @@ import ( ) func TestExpressionIS_NULL(t *testing.T) { - assertExpressionSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") - assertExpressionSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") + assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") } func TestExpressionIS_NOT_NULL(t *testing.T) { - assertExpressionSerialize(t, table2Col3.IS_NOT_NULL(), "table2.col3 IS NOT NULL") - assertExpressionSerialize(t, table2Col3.ADD(table2Col3).IS_NOT_NULL(), "(table2.col3 + table2.col3) IS NOT NULL") + assertClauseSerialize(t, table2Col3.IS_NOT_NULL(), "table2.col3 IS NOT NULL") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NOT_NULL(), "(table2.col3 + table2.col3) IS NOT NULL") } func TestExpressionIS_DISTINCT_FROM(t *testing.T) { - assertExpressionSerialize(t, table2Col3.IS_DISTINCT_FROM(table2Col4), "(table2.col3 IS DISTINCT FROM table2.col4)") - assertExpressionSerialize(t, table2Col3.ADD(table2Col3).IS_DISTINCT_FROM(Int(23)), "((table2.col3 + table2.col3) IS DISTINCT FROM $1)", int64(23)) + assertClauseSerialize(t, table2Col3.IS_DISTINCT_FROM(table2Col4), "(table2.col3 IS DISTINCT FROM table2.col4)") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_DISTINCT_FROM(Int(23)), "((table2.col3 + table2.col3) IS DISTINCT FROM $1)", int64(23)) } func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { - assertExpressionSerialize(t, table2Col3.IS_NOT_DISTINCT_FROM(table2Col4), "(table2.col3 IS NOT DISTINCT FROM table2.col4)") - assertExpressionSerialize(t, table2Col3.ADD(table2Col3).IS_NOT_DISTINCT_FROM(Int(23)), "((table2.col3 + table2.col3) IS NOT DISTINCT FROM $1)", int64(23)) + assertClauseSerialize(t, table2Col3.IS_NOT_DISTINCT_FROM(table2Col4), "(table2.col3 IS NOT DISTINCT FROM table2.col4)") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NOT_DISTINCT_FROM(Int(23)), "((table2.col3 + table2.col3) IS NOT DISTINCT FROM $1)", int64(23)) } func TestExpressionCAST_TO_BOOL(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_BOOL(), "table2.col3::boolean") - assertExpressionSerialize(t, table2Col3.ADD(table2Col3).CAST_TO_BOOL(), "(table2.col3 + table2.col3)::boolean") + assertClauseSerialize(t, table2Col3.CAST_TO_BOOL(), "table2.col3::boolean") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).CAST_TO_BOOL(), "(table2.col3 + table2.col3)::boolean") } func TestExpressionCAST_TO_INTEGER(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_INTEGER(), "table2.col3::integer") + assertClauseSerialize(t, table2Col3.CAST_TO_INTEGER(), "table2.col3::integer") } func TestExpressionCAST_TO_DOUBLE(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_DOUBLE(), "table2.col3::double precision") + assertClauseSerialize(t, table2Col3.CAST_TO_DOUBLE(), "table2.col3::double precision") } func TestExpressionCAST_TO_TEXT(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_TEXT(), "table2.col3::text") + assertClauseSerialize(t, table2Col3.CAST_TO_TEXT(), "table2.col3::text") } func TestExpressionCAST_TO_DATE(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_DATE(), "table2.col3::date") + assertClauseSerialize(t, table2Col3.CAST_TO_DATE(), "table2.col3::date") } func TestExpressionCAST_TO_TIME(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_TIME(), "table2.col3::time without time zone") + assertClauseSerialize(t, table2Col3.CAST_TO_TIME(), "table2.col3::time without time zone") } func TestExpressionCAST_TO_TIMEZ(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_TIMEZ(), "table2.col3::time with time zone") + assertClauseSerialize(t, table2Col3.CAST_TO_TIMEZ(), "table2.col3::time with time zone") } func TestExpressionCAST_TO_TIMESTAMP(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_TIMESTAMP(), "table2.col3::timestamp without time zone") + assertClauseSerialize(t, table2Col3.CAST_TO_TIMESTAMP(), "table2.col3::timestamp without time zone") } func TestExpressionCAST_TO_TIMESTAMPZ(t *testing.T) { - assertExpressionSerialize(t, table2Col3.CAST_TO_TIMESTAMPZ(), "table2.col3::timestamp with time zone") + assertClauseSerialize(t, table2Col3.CAST_TO_TIMESTAMPZ(), "table2.col3::timestamp with time zone") } func TestIN(t *testing.T) { - assertExpressionSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), + assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), `($1 IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 )))`, float64(1.11)) - assertExpressionSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), + assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), `(ROW($1, table1.col1) IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" @@ -79,13 +79,13 @@ func TestIN(t *testing.T) { func TestNOT_IN(t *testing.T) { - assertExpressionSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), + assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), `($1 NOT IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 )))`, float64(1.11)) - assertExpressionSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), + assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), `(ROW($1, table1.col1) NOT IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" diff --git a/sqlbuilder/float_expression_test.go b/sqlbuilder/float_expression_test.go index 0cff39d..5ff05e4 100644 --- a/sqlbuilder/float_expression_test.go +++ b/sqlbuilder/float_expression_test.go @@ -5,61 +5,61 @@ import ( ) func TestFloatExpressionEQ(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.EQ(table2ColFloat), "(table1.colFloat = table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.EQ(Float(2.11)), "(table1.colFloat = $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.EQ(table2ColFloat), "(table1.colFloat = table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.EQ(Float(2.11)), "(table1.colFloat = $1)", float64(2.11)) } func TestFloatExpressionNOT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.NOT_EQ(table2ColFloat), "(table1.colFloat != table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.NOT_EQ(Float(2.11)), "(table1.colFloat != $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.NOT_EQ(table2ColFloat), "(table1.colFloat != table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.NOT_EQ(Float(2.11)), "(table1.colFloat != $1)", float64(2.11)) } func TestFloatExpressionGT(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.GT(table2ColFloat), "(table1.colFloat > table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.GT(Float(2.11)), "(table1.colFloat > $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.GT(table2ColFloat), "(table1.colFloat > table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.GT(Float(2.11)), "(table1.colFloat > $1)", float64(2.11)) } func TestFloatExpressionGT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.GT_EQ(table2ColFloat), "(table1.colFloat >= table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.GT_EQ(Float(2.11)), "(table1.colFloat >= $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.GT_EQ(table2ColFloat), "(table1.colFloat >= table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.GT_EQ(Float(2.11)), "(table1.colFloat >= $1)", float64(2.11)) } func TestFloatExpressionLT(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.LT(table2ColFloat), "(table1.colFloat < table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.LT(Float(2.11)), "(table1.colFloat < $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.LT(table2ColFloat), "(table1.colFloat < table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.LT(Float(2.11)), "(table1.colFloat < $1)", float64(2.11)) } func TestFloatExpressionLT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.LT_EQ(table2ColFloat), "(table1.colFloat <= table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.LT_EQ(Float(2.11)), "(table1.colFloat <= $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.LT_EQ(table2ColFloat), "(table1.colFloat <= table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.LT_EQ(Float(2.11)), "(table1.colFloat <= $1)", float64(2.11)) } func TestFloatExpressionADD(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.ADD(table2ColFloat), "(table1.colFloat + table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.ADD(Float(2.11)), "(table1.colFloat + $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.ADD(table2ColFloat), "(table1.colFloat + table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.ADD(Float(2.11)), "(table1.colFloat + $1)", float64(2.11)) } func TestFloatExpressionSUB(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.SUB(table2ColFloat), "(table1.colFloat - table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.SUB(Float(2.11)), "(table1.colFloat - $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.SUB(table2ColFloat), "(table1.colFloat - table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.SUB(Float(2.11)), "(table1.colFloat - $1)", float64(2.11)) } func TestFloatExpressionMUL(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.MUL(table2ColFloat), "(table1.colFloat * table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.MUL(Float(2.11)), "(table1.colFloat * $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.MUL(table2ColFloat), "(table1.colFloat * table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.MUL(Float(2.11)), "(table1.colFloat * $1)", float64(2.11)) } func TestFloatExpressionDIV(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.DIV(table2ColFloat), "(table1.colFloat / table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.DIV(Float(2.11)), "(table1.colFloat / $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.DIV(table2ColFloat), "(table1.colFloat / table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.DIV(Float(2.11)), "(table1.colFloat / $1)", float64(2.11)) } func TestFloatExpressionMOD(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.MOD(table2ColFloat), "(table1.colFloat % table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.MOD(Float(2.11)), "(table1.colFloat % $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.MOD(table2ColFloat), "(table1.colFloat % table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.MOD(Float(2.11)), "(table1.colFloat % $1)", float64(2.11)) } func TestFloatExpressionPOW(t *testing.T) { - assertExpressionSerialize(t, table1ColFloat.POW(table2ColFloat), "(table1.colFloat ^ table2.colFloat)") - assertExpressionSerialize(t, table1ColFloat.POW(Float(2.11)), "(table1.colFloat ^ $1)", float64(2.11)) + assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "(table1.colFloat ^ table2.colFloat)") + assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "(table1.colFloat ^ $1)", float64(2.11)) } diff --git a/sqlbuilder/func_expression_test.go b/sqlbuilder/func_expression_test.go index ae0d1fc..30d2b16 100644 --- a/sqlbuilder/func_expression_test.go +++ b/sqlbuilder/func_expression_test.go @@ -7,169 +7,169 @@ import ( func TestFuncAVG(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, AVGf(table1ColFloat), "AVG(table1.colFloat)") + assertClauseSerialize(t, AVGf(table1ColFloat), "AVG(table1.colFloat)") }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, AVGi(table1ColInt), "AVG(table1.colInt)") + assertClauseSerialize(t, AVGi(table1ColInt), "AVG(table1.colInt)") }) } func TestFuncBIT_AND(t *testing.T) { - assertExpressionSerialize(t, BIT_AND(table1ColInt), "BIT_AND(table1.colInt)") + assertClauseSerialize(t, BIT_AND(table1ColInt), "BIT_AND(table1.colInt)") } func TestFuncBIT_OR(t *testing.T) { - assertExpressionSerialize(t, BIT_OR(table1ColInt), "BIT_OR(table1.colInt)") + assertClauseSerialize(t, BIT_OR(table1ColInt), "BIT_OR(table1.colInt)") } func TestFuncBOOL_AND(t *testing.T) { - assertExpressionSerialize(t, BOOL_AND(table1ColBool), "BOOL_AND(table1.colBool)") + assertClauseSerialize(t, BOOL_AND(table1ColBool), "BOOL_AND(table1.colBool)") } func TestFuncBOOL_OR(t *testing.T) { - assertExpressionSerialize(t, BOOL_OR(table1ColBool), "BOOL_OR(table1.colBool)") + assertClauseSerialize(t, BOOL_OR(table1ColBool), "BOOL_OR(table1.colBool)") } func TestFuncEVERY(t *testing.T) { - assertExpressionSerialize(t, EVERY(table1ColBool), "EVERY(table1.colBool)") + assertClauseSerialize(t, EVERY(table1ColBool), "EVERY(table1.colBool)") } func TestFuncMIN(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, MINf(table1ColFloat), "MIN(table1.colFloat)") + assertClauseSerialize(t, MINf(table1ColFloat), "MIN(table1.colFloat)") }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, MINi(table1ColInt), "MIN(table1.colInt)") + assertClauseSerialize(t, MINi(table1ColInt), "MIN(table1.colInt)") }) } func TestFuncMAX(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, MAXf(table1ColFloat), "MAX(table1.colFloat)") - assertExpressionSerialize(t, MAXf(Float(11.2222)), "MAX($1)", float64(11.2222)) + assertClauseSerialize(t, MAXf(table1ColFloat), "MAX(table1.colFloat)") + assertClauseSerialize(t, MAXf(Float(11.2222)), "MAX($1)", float64(11.2222)) }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, MAXi(table1ColInt), "MAX(table1.colInt)") - assertExpressionSerialize(t, MAXi(Int(11)), "MAX($1)", int64(11)) + assertClauseSerialize(t, MAXi(table1ColInt), "MAX(table1.colInt)") + assertClauseSerialize(t, MAXi(Int(11)), "MAX($1)", int64(11)) }) } func TestFuncSUM(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, SUMf(table1ColFloat), "SUM(table1.colFloat)") - assertExpressionSerialize(t, SUMf(Float(11.2222)), "SUM($1)", float64(11.2222)) + assertClauseSerialize(t, SUMf(table1ColFloat), "SUM(table1.colFloat)") + assertClauseSerialize(t, SUMf(Float(11.2222)), "SUM($1)", float64(11.2222)) }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, SUMi(table1ColInt), "SUM(table1.colInt)") - assertExpressionSerialize(t, SUMi(Int(11)), "SUM($1)", int64(11)) + assertClauseSerialize(t, SUMi(table1ColInt), "SUM(table1.colInt)") + assertClauseSerialize(t, SUMi(Int(11)), "SUM($1)", int64(11)) }) } func TestFuncCOUNT(t *testing.T) { - assertExpressionSerialize(t, COUNT(STAR), "COUNT(*)") - assertExpressionSerialize(t, COUNT(table1ColFloat), "COUNT(table1.colFloat)") - assertExpressionSerialize(t, COUNT(Float(11.2222)), "COUNT($1)", float64(11.2222)) + assertClauseSerialize(t, COUNT(STAR), "COUNT(*)") + assertClauseSerialize(t, COUNT(table1ColFloat), "COUNT(table1.colFloat)") + assertClauseSerialize(t, COUNT(Float(11.2222)), "COUNT($1)", float64(11.2222)) } func TestFuncABS(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, ABSf(table1ColFloat), "ABS(table1.colFloat)") - assertExpressionSerialize(t, ABSf(Float(11.2222)), "ABS($1)", float64(11.2222)) + assertClauseSerialize(t, ABSf(table1ColFloat), "ABS(table1.colFloat)") + assertClauseSerialize(t, ABSf(Float(11.2222)), "ABS($1)", float64(11.2222)) }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, ABSi(table1ColInt), "ABS(table1.colInt)") - assertExpressionSerialize(t, ABSi(Int(11)), "ABS($1)", int64(11)) + assertClauseSerialize(t, ABSi(table1ColInt), "ABS(table1.colInt)") + assertClauseSerialize(t, ABSi(Int(11)), "ABS($1)", int64(11)) }) } func TestFuncSQRT(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, SQRTf(table1ColFloat), "SQRT(table1.colFloat)") - assertExpressionSerialize(t, SQRTf(Float(11.2222)), "SQRT($1)", float64(11.2222)) + assertClauseSerialize(t, SQRTf(table1ColFloat), "SQRT(table1.colFloat)") + assertClauseSerialize(t, SQRTf(Float(11.2222)), "SQRT($1)", float64(11.2222)) }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, SQRTi(table1ColInt), "SQRT(table1.colInt)") - assertExpressionSerialize(t, SQRTi(Int(11)), "SQRT($1)", int64(11)) + assertClauseSerialize(t, SQRTi(table1ColInt), "SQRT(table1.colInt)") + assertClauseSerialize(t, SQRTi(Int(11)), "SQRT($1)", int64(11)) }) } func TestFuncCBRT(t *testing.T) { t.Run("float", func(t *testing.T) { - assertExpressionSerialize(t, CBRTf(table1ColFloat), "CBRT(table1.colFloat)") - assertExpressionSerialize(t, CBRTf(Float(11.2222)), "CBRT($1)", float64(11.2222)) + assertClauseSerialize(t, CBRTf(table1ColFloat), "CBRT(table1.colFloat)") + assertClauseSerialize(t, CBRTf(Float(11.2222)), "CBRT($1)", float64(11.2222)) }) t.Run("integer", func(t *testing.T) { - assertExpressionSerialize(t, CBRTi(table1ColInt), "CBRT(table1.colInt)") - assertExpressionSerialize(t, CBRTi(Int(11)), "CBRT($1)", int64(11)) + assertClauseSerialize(t, CBRTi(table1ColInt), "CBRT(table1.colInt)") + assertClauseSerialize(t, CBRTi(Int(11)), "CBRT($1)", int64(11)) }) } func TestFuncCEIL(t *testing.T) { - assertExpressionSerialize(t, CEIL(table1ColFloat), "CEIL(table1.colFloat)") - assertExpressionSerialize(t, CEIL(Float(11.2222)), "CEIL($1)", float64(11.2222)) + assertClauseSerialize(t, CEIL(table1ColFloat), "CEIL(table1.colFloat)") + assertClauseSerialize(t, CEIL(Float(11.2222)), "CEIL($1)", float64(11.2222)) } func TestFuncFLOOR(t *testing.T) { - assertExpressionSerialize(t, FLOOR(table1ColFloat), "FLOOR(table1.colFloat)") - assertExpressionSerialize(t, FLOOR(Float(11.2222)), "FLOOR($1)", float64(11.2222)) + assertClauseSerialize(t, FLOOR(table1ColFloat), "FLOOR(table1.colFloat)") + assertClauseSerialize(t, FLOOR(Float(11.2222)), "FLOOR($1)", float64(11.2222)) } func TestFuncROUND(t *testing.T) { - assertExpressionSerialize(t, ROUND(table1ColFloat), "ROUND(table1.colFloat)") - assertExpressionSerialize(t, ROUND(Float(11.2222)), "ROUND($1)", float64(11.2222)) + assertClauseSerialize(t, ROUND(table1ColFloat), "ROUND(table1.colFloat)") + assertClauseSerialize(t, ROUND(Float(11.2222)), "ROUND($1)", float64(11.2222)) - assertExpressionSerialize(t, ROUND(table1ColFloat, Int(2)), "ROUND(table1.colFloat, $1)", int64(2)) - assertExpressionSerialize(t, ROUND(Float(11.2222), Int(1)), "ROUND($1, $2)", float64(11.2222), int64(1)) + assertClauseSerialize(t, ROUND(table1ColFloat, Int(2)), "ROUND(table1.colFloat, $1)", int64(2)) + assertClauseSerialize(t, ROUND(Float(11.2222), Int(1)), "ROUND($1, $2)", float64(11.2222), int64(1)) } func TestFuncSIGN(t *testing.T) { - assertExpressionSerialize(t, SIGN(table1ColFloat), "SIGN(table1.colFloat)") - assertExpressionSerialize(t, SIGN(Float(11.2222)), "SIGN($1)", float64(11.2222)) + assertClauseSerialize(t, SIGN(table1ColFloat), "SIGN(table1.colFloat)") + assertClauseSerialize(t, SIGN(Float(11.2222)), "SIGN($1)", float64(11.2222)) } func TestFuncTRUNC(t *testing.T) { - assertExpressionSerialize(t, TRUNC(table1ColFloat), "TRUNC(table1.colFloat)") - assertExpressionSerialize(t, TRUNC(Float(11.2222)), "TRUNC($1)", float64(11.2222)) + assertClauseSerialize(t, TRUNC(table1ColFloat), "TRUNC(table1.colFloat)") + assertClauseSerialize(t, TRUNC(Float(11.2222)), "TRUNC($1)", float64(11.2222)) - assertExpressionSerialize(t, TRUNC(table1ColFloat, Int(2)), "TRUNC(table1.colFloat, $1)", int64(2)) - assertExpressionSerialize(t, TRUNC(Float(11.2222), Int(1)), "TRUNC($1, $2)", float64(11.2222), int64(1)) + assertClauseSerialize(t, TRUNC(table1ColFloat, Int(2)), "TRUNC(table1.colFloat, $1)", int64(2)) + assertClauseSerialize(t, TRUNC(Float(11.2222), Int(1)), "TRUNC($1, $2)", float64(11.2222), int64(1)) } func TestFuncLN(t *testing.T) { - assertExpressionSerialize(t, LN(table1ColFloat), "LN(table1.colFloat)") - assertExpressionSerialize(t, LN(Float(11.2222)), "LN($1)", float64(11.2222)) + assertClauseSerialize(t, LN(table1ColFloat), "LN(table1.colFloat)") + assertClauseSerialize(t, LN(Float(11.2222)), "LN($1)", float64(11.2222)) } func TestFuncLOG(t *testing.T) { - assertExpressionSerialize(t, LOG(table1ColFloat), "LOG(table1.colFloat)") - assertExpressionSerialize(t, LOG(Float(11.2222)), "LOG($1)", float64(11.2222)) + assertClauseSerialize(t, LOG(table1ColFloat), "LOG(table1.colFloat)") + assertClauseSerialize(t, LOG(Float(11.2222)), "LOG($1)", float64(11.2222)) } func TestFuncCOALESCE(t *testing.T) { - assertExpressionSerialize(t, COALESCE(table1ColFloat), "COALESCE(table1.colFloat)") - assertExpressionSerialize(t, COALESCE(Float(11.2222), NULL, String("str")), "COALESCE($1, NULL, $2)", float64(11.2222), "str") + assertClauseSerialize(t, COALESCE(table1ColFloat), "COALESCE(table1.colFloat)") + assertClauseSerialize(t, COALESCE(Float(11.2222), NULL, String("str")), "COALESCE($1, NULL, $2)", float64(11.2222), "str") } func TestFuncNULLIF(t *testing.T) { - assertExpressionSerialize(t, NULLIF(table1ColFloat, table2ColInt), "NULLIF(table1.colFloat, table2.colInt)") - assertExpressionSerialize(t, NULLIF(Float(11.2222), NULL), "NULLIF($1, NULL)", float64(11.2222)) + assertClauseSerialize(t, NULLIF(table1ColFloat, table2ColInt), "NULLIF(table1.colFloat, table2.colInt)") + assertClauseSerialize(t, NULLIF(Float(11.2222), NULL), "NULLIF($1, NULL)", float64(11.2222)) } func TestFuncGREATEST(t *testing.T) { - assertExpressionSerialize(t, GREATEST(table1ColFloat), "GREATEST(table1.colFloat)") - assertExpressionSerialize(t, GREATEST(Float(11.2222), NULL, String("str")), "GREATEST($1, NULL, $2)", float64(11.2222), "str") + assertClauseSerialize(t, GREATEST(table1ColFloat), "GREATEST(table1.colFloat)") + assertClauseSerialize(t, GREATEST(Float(11.2222), NULL, String("str")), "GREATEST($1, NULL, $2)", float64(11.2222), "str") } func TestFuncLEAST(t *testing.T) { - assertExpressionSerialize(t, LEAST(table1ColFloat), "LEAST(table1.colFloat)") - assertExpressionSerialize(t, LEAST(Float(11.2222), NULL, String("str")), "LEAST($1, NULL, $2)", float64(11.2222), "str") + assertClauseSerialize(t, LEAST(table1ColFloat), "LEAST(table1.colFloat)") + assertClauseSerialize(t, LEAST(Float(11.2222), NULL, String("str")), "LEAST($1, NULL, $2)", float64(11.2222), "str") } func TestInterval(t *testing.T) { diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 18f4e87..dba3b9b 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -2,7 +2,7 @@ package sqlbuilder import ( "database/sql" - "github.com/dropbox/godropbox/errors" + "errors" "github.com/serenize/snaker" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" "reflect" @@ -22,7 +22,7 @@ type InsertStatement interface { QUERY(selectStatement SelectStatement) InsertStatement } -func newInsertStatement(t writableTable, columns ...column) InsertStatement { +func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { return &insertStatementImpl{ table: t, columns: columns, @@ -30,8 +30,8 @@ func newInsertStatement(t writableTable, columns ...column) InsertStatement { } type insertStatementImpl struct { - table writableTable - columns []column + table WritableTable + columns []Column rows [][]clause query SelectStatement returning []projection @@ -136,7 +136,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) queryData.writeString("INSERT INTO") if s.table == nil { - return "", nil, errors.Newf("nil tableName.") + return "", nil, errors.New("nil tableName.") } err = s.table.serialize(insert_statement, queryData) diff --git a/sqlbuilder/integer_expression_test.go b/sqlbuilder/integer_expression_test.go index 8b0e4ca..af0cffa 100644 --- a/sqlbuilder/integer_expression_test.go +++ b/sqlbuilder/integer_expression_test.go @@ -5,71 +5,71 @@ import ( ) func TestIntegerExpressionEQ(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.EQ(table2ColInt), "(table1.colInt = table2.colInt)") - assertExpressionSerialize(t, table1ColInt.EQ(Int(11)), "(table1.colInt = $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.EQ(table2ColInt), "(table1.colInt = table2.colInt)") + assertClauseSerialize(t, table1ColInt.EQ(Int(11)), "(table1.colInt = $1)", int64(11)) } func TestIntegerExpressionNOT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.NOT_EQ(table2ColInt), "(table1.colInt != table2.colInt)") - assertExpressionSerialize(t, table1ColInt.NOT_EQ(Int(11)), "(table1.colInt != $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.NOT_EQ(table2ColInt), "(table1.colInt != table2.colInt)") + assertClauseSerialize(t, table1ColInt.NOT_EQ(Int(11)), "(table1.colInt != $1)", int64(11)) } func TestIntegerExpressionGT(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.GT(table2ColInt), "(table1.colInt > table2.colInt)") - assertExpressionSerialize(t, table1ColInt.GT(Int(11)), "(table1.colInt > $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.GT(table2ColInt), "(table1.colInt > table2.colInt)") + assertClauseSerialize(t, table1ColInt.GT(Int(11)), "(table1.colInt > $1)", int64(11)) } func TestIntegerExpressionGT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.GT_EQ(table2ColInt), "(table1.colInt >= table2.colInt)") - assertExpressionSerialize(t, table1ColInt.GT_EQ(Int(11)), "(table1.colInt >= $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.GT_EQ(table2ColInt), "(table1.colInt >= table2.colInt)") + assertClauseSerialize(t, table1ColInt.GT_EQ(Int(11)), "(table1.colInt >= $1)", int64(11)) } func TestIntegerExpressionLT(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.LT(table2ColInt), "(table1.colInt < table2.colInt)") - assertExpressionSerialize(t, table1ColInt.LT(Int(11)), "(table1.colInt < $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.LT(table2ColInt), "(table1.colInt < table2.colInt)") + assertClauseSerialize(t, table1ColInt.LT(Int(11)), "(table1.colInt < $1)", int64(11)) } func TestIntegerExpressionLT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.LT_EQ(table2ColInt), "(table1.colInt <= table2.colInt)") - assertExpressionSerialize(t, table1ColInt.LT_EQ(Int(11)), "(table1.colInt <= $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.LT_EQ(table2ColInt), "(table1.colInt <= table2.colInt)") + assertClauseSerialize(t, table1ColInt.LT_EQ(Int(11)), "(table1.colInt <= $1)", int64(11)) } func TestIntegerExpressionADD(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.ADD(table2ColInt), "(table1.colInt + table2.colInt)") - assertExpressionSerialize(t, table1ColInt.ADD(Int(11)), "(table1.colInt + $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.ADD(table2ColInt), "(table1.colInt + table2.colInt)") + assertClauseSerialize(t, table1ColInt.ADD(Int(11)), "(table1.colInt + $1)", int64(11)) } func TestIntegerExpressionSUB(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.SUB(table2ColInt), "(table1.colInt - table2.colInt)") - assertExpressionSerialize(t, table1ColInt.SUB(Int(11)), "(table1.colInt - $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.SUB(table2ColInt), "(table1.colInt - table2.colInt)") + assertClauseSerialize(t, table1ColInt.SUB(Int(11)), "(table1.colInt - $1)", int64(11)) } func TestIntegerExpressionMUL(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.MUL(table2ColInt), "(table1.colInt * table2.colInt)") - assertExpressionSerialize(t, table1ColInt.MUL(Int(11)), "(table1.colInt * $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.MUL(table2ColInt), "(table1.colInt * table2.colInt)") + assertClauseSerialize(t, table1ColInt.MUL(Int(11)), "(table1.colInt * $1)", int64(11)) } func TestIntegerExpressionDIV(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.colInt / table2.colInt)") - assertExpressionSerialize(t, table1ColInt.DIV(Int(11)), "(table1.colInt / $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.colInt / table2.colInt)") + assertClauseSerialize(t, table1ColInt.DIV(Int(11)), "(table1.colInt / $1)", int64(11)) } func TestIntExpressionMOD(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.MOD(table2ColInt), "(table1.colInt % table2.colInt)") - assertExpressionSerialize(t, table1ColInt.MOD(Int(11)), "(table1.colInt % $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.MOD(table2ColInt), "(table1.colInt % table2.colInt)") + assertClauseSerialize(t, table1ColInt.MOD(Int(11)), "(table1.colInt % $1)", int64(11)) } func TestIntExpressionPOW(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.POW(table2ColInt), "(table1.colInt ^ table2.colInt)") - assertExpressionSerialize(t, table1ColInt.POW(Int(11)), "(table1.colInt ^ $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "(table1.colInt ^ table2.colInt)") + assertClauseSerialize(t, table1ColInt.POW(Int(11)), "(table1.colInt ^ $1)", int64(11)) } func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.BIT_SHIFT_LEFT(table2ColInt), "(table1.colInt << table2.colInt)") - assertExpressionSerialize(t, table1ColInt.BIT_SHIFT_LEFT(Int(11)), "(table1.colInt << $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.BIT_SHIFT_LEFT(table2ColInt), "(table1.colInt << table2.colInt)") + assertClauseSerialize(t, table1ColInt.BIT_SHIFT_LEFT(Int(11)), "(table1.colInt << $1)", int64(11)) } func TestIntExpressionBIT_SHIFT_RIGHT(t *testing.T) { - assertExpressionSerialize(t, table1ColInt.BIT_SHIFT_RIGHT(table2ColInt), "(table1.colInt >> table2.colInt)") - assertExpressionSerialize(t, table1ColInt.BIT_SHIFT_RIGHT(Int(11)), "(table1.colInt >> $1)", int64(11)) + assertClauseSerialize(t, table1ColInt.BIT_SHIFT_RIGHT(table2ColInt), "(table1.colInt >> table2.colInt)") + assertClauseSerialize(t, table1ColInt.BIT_SHIFT_RIGHT(Int(11)), "(table1.colInt >> $1)", int64(11)) } diff --git a/sqlbuilder/literal_expression_test.go b/sqlbuilder/literal_expression_test.go index a5ac931..7c134bf 100644 --- a/sqlbuilder/literal_expression_test.go +++ b/sqlbuilder/literal_expression_test.go @@ -3,5 +3,5 @@ package sqlbuilder import "testing" func TestRawExpression(t *testing.T) { - assertExpressionSerialize(t, RAW("current_database()"), "current_database()") + assertClauseSerialize(t, RAW("current_database()"), "current_database()") } diff --git a/sqlbuilder/lock_statement.go b/sqlbuilder/lock_statement.go index 2d9a77c..ce113de 100644 --- a/sqlbuilder/lock_statement.go +++ b/sqlbuilder/lock_statement.go @@ -27,12 +27,12 @@ type LockStatement interface { } type lockStatementImpl struct { - tables []tableInterface + tables []WritableTable lockMode lockMode nowait bool } -func LOCK(tables ...tableInterface) LockStatement { +func LOCK(tables ...WritableTable) LockStatement { return &lockStatementImpl{ tables: tables, } diff --git a/sqlbuilder/operators.go b/sqlbuilder/operators.go index dd2be31..db23f1b 100644 --- a/sqlbuilder/operators.go +++ b/sqlbuilder/operators.go @@ -49,47 +49,6 @@ func GT_EQ(lhs, rhs Expression) BoolExpression { return newBinaryBoolExpression(lhs, rhs, ">=") } -func IS_TRUE(expr BoolExpression) BoolExpression { - return newPostifxBoolExpression(expr, "IS TRUE") -} - -func IS_NOT_TRUE(expr BoolExpression) BoolExpression { - return newPostifxBoolExpression(expr, "IS NOT TRUE") -} - -func IS_FALSE(expr BoolExpression) BoolExpression { - return newPostifxBoolExpression(expr, "IS FALSE") -} - -func IS_NOT_FALSE(expr BoolExpression) BoolExpression { - return newPostifxBoolExpression(expr, "IS NOT FALSE") -} - -func IS_UNKNOWN(expr BoolExpression) BoolExpression { - return newPostifxBoolExpression(expr, "IS UNKNOWN") -} - -func IS_NOT_UNKNOWN(expr BoolExpression) BoolExpression { - return newPostifxBoolExpression(expr, "IS NOT UNKNOWN") -} - -func And(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, "AND") -} - -// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses -func Or(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, "OR") -} - -func Regexp(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, "REGEXP") -} - -func RegexpL(lhs Expression, val string) BoolExpression { - return Regexp(lhs, literal(val)) -} - func EXISTS(subQuery SelectStatement) BoolExpression { return newPrefixBoolExpression(subQuery, "EXISTS") } diff --git a/sqlbuilder/operators_test.go b/sqlbuilder/operators_test.go index 9b13e73..f587ec7 100644 --- a/sqlbuilder/operators_test.go +++ b/sqlbuilder/operators_test.go @@ -5,9 +5,9 @@ import "testing" func TestOperatorNOT(t *testing.T) { notExpression := NOT(Int(2).EQ(Int(1))) - assertExpressionSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1)) + assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1)) assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1)) - assertExpressionSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5)) + assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5)) } func TestCase1(t *testing.T) { @@ -15,7 +15,7 @@ func TestCase1(t *testing.T) { WHEN(table3Col1.EQ(Int(1))).THEN(table3Col1.ADD(Int(1))). WHEN(table3Col1.EQ(Int(2))).THEN(table3Col1.ADD(Int(2))) - assertExpressionSerialize(t, query, `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`, + assertClauseSerialize(t, query, `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`, int64(1), int64(1), int64(2), int64(2)) } @@ -25,6 +25,6 @@ func TestCase2(t *testing.T) { WHEN(Int(2)).THEN(table3Col1.ADD(Int(2))). ELSE(Int(0)) - assertExpressionSerialize(t, query, `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`, + assertClauseSerialize(t, query, `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`, int64(1), int64(1), int64(2), int64(2), int64(0)) } diff --git a/sqlbuilder/order_by_clause.go b/sqlbuilder/order_by_clause.go index b4655f8..2119c72 100644 --- a/sqlbuilder/order_by_clause.go +++ b/sqlbuilder/order_by_clause.go @@ -1,6 +1,6 @@ package sqlbuilder -import "github.com/dropbox/godropbox/errors" +import "errors" type OrderByClause interface { serializeAsOrderBy(statement statementType, out *queryData) error @@ -13,7 +13,7 @@ type orderByClauseImpl struct { func (o *orderByClauseImpl) serializeAsOrderBy(statement statementType, out *queryData) error { if o.expression == nil { - return errors.Newf("nil orderBy by clause.") + return errors.New("nil orderBy by clause.") } if err := o.expression.serializeAsOrderBy(statement, out); err != nil { diff --git a/sqlbuilder/projection.go b/sqlbuilder/projection.go index a90914d..6e5e189 100644 --- a/sqlbuilder/projection.go +++ b/sqlbuilder/projection.go @@ -6,7 +6,7 @@ type projection interface { //------------------------------------------------------// // Dummy type for select * AllColumns -type ColumnList []column +type ColumnList []Column func (cl ColumnList) isProjectionType() {} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 26eb2e4..65f20b4 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -2,7 +2,7 @@ package sqlbuilder import ( "database/sql" - "github.com/dropbox/godropbox/errors" + "errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" ) @@ -30,7 +30,7 @@ func SELECT(projection ...projection) SelectStatement { return newSelectStatement(nil, projection) } -// NOTE: SelectStatement purposely does not implement the Table interface since +// NOTE: SelectStatement purposely does not implement the tableImpl interface since // mysql's subquery performance is horrible. type selectStatementImpl struct { expressionInterfaceImpl @@ -53,7 +53,7 @@ func defaultProjectionAliasing(projections []projection) []projection { aliasedProjections := []projection{} for _, projection := range projections { - if column, ok := projection.(column); ok { + if column, ok := projection.(Column); ok { aliasedProjections = append(aliasedProjections, column.DefaultAlias()) } else if columnList, ok := projection.(ColumnList); ok { aliasedProjections = append(aliasedProjections, columnList.DefaultAlias()...) @@ -87,7 +87,7 @@ func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { func (s *selectStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { if s == nil { - return errors.New("Select statement is nil. ") + return errors.New("Select expression is nil. ") } out.writeString("(") @@ -107,7 +107,7 @@ func (s *selectStatementImpl) serialize(statement statementType, out *queryData, func (s *selectStatementImpl) serializeImpl(out *queryData) error { if s == nil { - return errors.New("Select statement is nil. ") + return errors.New("Select expression is nil. ") } out.nextLine() @@ -205,10 +205,7 @@ func (s *selectStatementImpl) DebugSql() (query string, err error) { } func (s *selectStatementImpl) AsTable(alias string) expressionTable { - return &expressionTableImpl{ - statement: s, - alias: alias, - } + return newExpressionTable(s.parent, alias) } func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index cd2b751..6bb4a4f 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -2,7 +2,7 @@ package sqlbuilder import ( "database/sql" - "github.com/dropbox/godropbox/errors" + "errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" ) @@ -92,15 +92,12 @@ func (us *setStatementImpl) OFFSET(offset int64) SetStatement { } func (us *setStatementImpl) AsTable(alias string) expressionTable { - return &expressionTableImpl{ - statement: us, - alias: alias, - } + return newExpressionTable(us.parent, alias) } func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { if s == nil { - return errors.New("Set statement is nil. ") + return errors.New("Set expression is nil. ") } if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { @@ -125,11 +122,11 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData, op func (s *setStatementImpl) serializeImpl(out *queryData) error { if s == nil { - return errors.New("Set statement is nil. ") + return errors.New("Set expression is nil. ") } if len(s.selects) < 2 { - return errors.Newf("UNION Statement must have at least two SELECT statements.") + return errors.New("UNION Statement must have at least two SELECT statements.") } out.nextLine() diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index d2bd06f..03c15d6 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -7,7 +7,7 @@ import ( gc "gopkg.in/check.v1" - "github.com/dropbox/godropbox/errors" + "errors" ) type StmtSuite struct { diff --git a/sqlbuilder/string_expression_test.go b/sqlbuilder/string_expression_test.go index a697a53..366324b 100644 --- a/sqlbuilder/string_expression_test.go +++ b/sqlbuilder/string_expression_test.go @@ -6,62 +6,62 @@ import ( func TestStringEQ(t *testing.T) { exp := table3StrCol.EQ(table2ColStr) - assertExpressionSerialize(t, exp, "(table3.col2 = table2.colStr)") + assertClauseSerialize(t, exp, "(table3.col2 = table2.colStr)") exp = table3StrCol.EQ(String("JOHN")) - assertExpressionSerialize(t, exp, "(table3.col2 = $1)", "JOHN") + assertClauseSerialize(t, exp, "(table3.col2 = $1)", "JOHN") } func TestStringNOT_EQ(t *testing.T) { exp := table3StrCol.NOT_EQ(table2ColStr) - assertExpressionSerialize(t, exp, "(table3.col2 != table2.colStr)") - assertExpressionSerialize(t, table3StrCol.NOT_EQ(String("JOHN")), "(table3.col2 != $1)", "JOHN") + assertClauseSerialize(t, exp, "(table3.col2 != table2.colStr)") + assertClauseSerialize(t, table3StrCol.NOT_EQ(String("JOHN")), "(table3.col2 != $1)", "JOHN") } func TestStringGT(t *testing.T) { exp := table3StrCol.GT(table2ColStr) - assertExpressionSerialize(t, exp, "(table3.col2 > table2.colStr)") - assertExpressionSerialize(t, table3StrCol.GT(String("JOHN")), "(table3.col2 > $1)", "JOHN") + assertClauseSerialize(t, exp, "(table3.col2 > table2.colStr)") + assertClauseSerialize(t, table3StrCol.GT(String("JOHN")), "(table3.col2 > $1)", "JOHN") } func TestStringGT_EQ(t *testing.T) { exp := table3StrCol.GT_EQ(table2ColStr) - assertExpressionSerialize(t, exp, "(table3.col2 >= table2.colStr)") - assertExpressionSerialize(t, table3StrCol.GT_EQ(String("JOHN")), "(table3.col2 >= $1)", "JOHN") + assertClauseSerialize(t, exp, "(table3.col2 >= table2.colStr)") + assertClauseSerialize(t, table3StrCol.GT_EQ(String("JOHN")), "(table3.col2 >= $1)", "JOHN") } func TestStringLT(t *testing.T) { exp := table3StrCol.LT(table2ColStr) - assertExpressionSerialize(t, exp, "(table3.col2 < table2.colStr)") - assertExpressionSerialize(t, table3StrCol.LT(String("JOHN")), "(table3.col2 < $1)", "JOHN") + assertClauseSerialize(t, exp, "(table3.col2 < table2.colStr)") + assertClauseSerialize(t, table3StrCol.LT(String("JOHN")), "(table3.col2 < $1)", "JOHN") } func TestStringLT_EQ(t *testing.T) { exp := table3StrCol.LT_EQ(table2ColStr) - assertExpressionSerialize(t, exp, "(table3.col2 <= table2.colStr)") - assertExpressionSerialize(t, table3StrCol.LT_EQ(String("JOHN")), "(table3.col2 <= $1)", "JOHN") + assertClauseSerialize(t, exp, "(table3.col2 <= table2.colStr)") + assertClauseSerialize(t, table3StrCol.LT_EQ(String("JOHN")), "(table3.col2 <= $1)", "JOHN") } func TestStringCONCAT(t *testing.T) { - assertExpressionSerialize(t, table3StrCol.CONCAT(table2ColStr), "(table3.col2 || table2.colStr)") - assertExpressionSerialize(t, table3StrCol.CONCAT(String("JOHN")), "(table3.col2 || $1)", "JOHN") + assertClauseSerialize(t, table3StrCol.CONCAT(table2ColStr), "(table3.col2 || table2.colStr)") + assertClauseSerialize(t, table3StrCol.CONCAT(String("JOHN")), "(table3.col2 || $1)", "JOHN") } func TestStringLIKE(t *testing.T) { - assertExpressionSerialize(t, table3StrCol.LIKE(table2ColStr), "(table3.col2 LIKE table2.colStr)") - assertExpressionSerialize(t, table3StrCol.LIKE(String("JOHN")), "(table3.col2 LIKE $1)", "JOHN") + assertClauseSerialize(t, table3StrCol.LIKE(table2ColStr), "(table3.col2 LIKE table2.colStr)") + assertClauseSerialize(t, table3StrCol.LIKE(String("JOHN")), "(table3.col2 LIKE $1)", "JOHN") } func TestStringNOT_LIKE(t *testing.T) { - assertExpressionSerialize(t, table3StrCol.NOT_LIKE(table2ColStr), "(table3.col2 NOT LIKE table2.colStr)") - assertExpressionSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN") + assertClauseSerialize(t, table3StrCol.NOT_LIKE(table2ColStr), "(table3.col2 NOT LIKE table2.colStr)") + assertClauseSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN") } func TestStringSIMILAR_TO(t *testing.T) { - assertExpressionSerialize(t, table3StrCol.SIMILAR_TO(table2ColStr), "(table3.col2 SIMILAR TO table2.colStr)") - assertExpressionSerialize(t, table3StrCol.SIMILAR_TO(String("JOHN")), "(table3.col2 SIMILAR TO $1)", "JOHN") + assertClauseSerialize(t, table3StrCol.SIMILAR_TO(table2ColStr), "(table3.col2 SIMILAR TO table2.colStr)") + assertClauseSerialize(t, table3StrCol.SIMILAR_TO(String("JOHN")), "(table3.col2 SIMILAR TO $1)", "JOHN") } func TestStringNOT_SIMILAR_TO(t *testing.T) { - assertExpressionSerialize(t, table3StrCol.NOT_SIMILAR_TO(table2ColStr), "(table3.col2 NOT SIMILAR TO table2.colStr)") - assertExpressionSerialize(t, table3StrCol.NOT_SIMILAR_TO(String("JOHN")), "(table3.col2 NOT SIMILAR TO $1)", "JOHN") + assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(table2ColStr), "(table3.col2 NOT SIMILAR TO table2.colStr)") + assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(String("JOHN")), "(table3.col2 NOT SIMILAR TO $1)", "JOHN") } diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 220218e..8b68494 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -3,22 +3,10 @@ package sqlbuilder import ( - "github.com/dropbox/godropbox/errors" + "errors" ) -type tableInterface interface { - clause - SchemaName() string - TableName() string - - Columns() []column -} - -// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause -// are not supported. -type ReadableTable interface { - tableInterface - +type readableTable interface { // Generates a select query on the current tableName. SELECT(projections ...projection) SelectStatement @@ -38,20 +26,87 @@ type ReadableTable interface { // The sql tableName write interface. type writableTable interface { - tableInterface - - INSERT(columns ...column) InsertStatement - UPDATE(columns ...column) UpdateStatement + INSERT(columns ...Column) InsertStatement + UPDATE(columns ...Column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement } -// Defines a physical tableName in the database that is both readable and writable. -// This function will panic if name is not valid -func NewTable(schemaName, name string, columns ...column) *Table { +type ReadableTable interface { + readableTable + clause +} - t := &Table{ +type WritableTable interface { + writableTable + clause +} + +type Table interface { + readableTable + writableTable + clause + SchemaName() string + TableName() string + AS(alias string) +} + +type readableTableInterfaceImpl struct { + parent ReadableTable +} + +// Generates a select query on the current tableName. +func (r *readableTableInterfaceImpl) SELECT(projections ...projection) SelectStatement { + return newSelectStatement(r.parent, projections) +} + +// Creates a inner join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, innerJoin, onCondition) +} + +// Creates a left join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, leftJoin, onCondition) +} + +// Creates a right join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, rightJoin, onCondition) +} + +func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, fullJoin, onCondition) +} + +func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { + return newJoinTable(r.parent, table, crossJoin, nil) +} + +type writableTableInterfaceImpl struct { + parent WritableTable +} + +func (w *writableTableInterfaceImpl) INSERT(columns ...Column) InsertStatement { + return newInsertStatement(w.parent, columns...) +} + +func (w *writableTableInterfaceImpl) UPDATE(columns ...Column) UpdateStatement { + return newUpdateStatement(w.parent, columns) +} + +func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { + return newDeleteStatement(w.parent) +} + +func (w *writableTableInterfaceImpl) LOCK() LockStatement { + return LOCK(w.parent) +} + +func NewTable(schemaName, name string, columns ...Column) Table { + + t := &tableImpl{ schemaName: schemaName, name: name, columns: columns, @@ -60,25 +115,23 @@ func NewTable(schemaName, name string, columns ...column) *Table { c.setTableName(name) } + t.readableTableInterfaceImpl.parent = t + t.writableTableInterfaceImpl.parent = t + return t } -type Table struct { +type tableImpl struct { + readableTableInterfaceImpl + writableTableInterfaceImpl + schemaName string name string alias string - columns []column + columns []Column } -func (t *Table) Column(name string) column { - return &baseColumn{ - name: name, - nullable: NotNullable, - tableName: t.name, - } -} - -func (t *Table) SetAlias(alias string) { +func (t *tableImpl) AS(alias string) { t.alias = alias for _, c := range t.columns { @@ -87,29 +140,22 @@ func (t *Table) SetAlias(alias string) { } // Returns the tableName's name in the database -func (t *Table) SchemaName() string { +func (t *tableImpl) SchemaName() string { return t.schemaName } // Returns the tableName's name in the database -func (t *Table) TableName() string { +func (t *tableImpl) TableName() string { return t.name } -func (t *Table) SchemaTableName() string { +func (t *tableImpl) SchemaTableName() string { return t.schemaName } -// Returns a list of the tableName's columns -func (t *Table) Columns() []column { - return t.columns -} - -// Generates the sql string for the current tableName Expression. Note: the -// generated string may not be a valid/executable sql Statement. -func (t *Table) serialize(statement statementType, out *queryData, options ...serializeOption) error { +func (t *tableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { if t == nil { - return errors.Newf("Table is nil. ") + return errors.New("tableImpl is nil. ") } out.writeString(t.schemaName) @@ -124,71 +170,20 @@ func (t *Table) serialize(statement statementType, out *queryData, options ...se return nil } -// Generates a select query on the current tableName. -func (t *Table) SELECT(projections ...projection) SelectStatement { - return newSelectStatement(t, projections) -} - -// Creates a inner join tableName Expression using onCondition. -func (t *Table) INNER_JOIN( - table ReadableTable, - onCondition BoolExpression) ReadableTable { - - return InnerJoinOn(t, table, onCondition) -} - -// Creates a left join tableName Expression using onCondition. -func (t *Table) LEFT_JOIN( - table ReadableTable, - onCondition BoolExpression) ReadableTable { - - return LeftJoinOn(t, table, onCondition) -} - -// Creates a right join tableName Expression using onCondition. -func (t *Table) RIGHT_JOIN( - table ReadableTable, - onCondition BoolExpression) ReadableTable { - - return RightJoinOn(t, table, onCondition) -} - -func (t *Table) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return FullJoin(t, table, onCondition) -} - -func (t *Table) CROSS_JOIN(table ReadableTable) ReadableTable { - return CrossJoin(t, table) -} - -func (t *Table) INSERT(columns ...column) InsertStatement { - return newInsertStatement(t, columns...) -} - -func (t *Table) UPDATE(columns ...column) UpdateStatement { - return newUpdateStatement(t, columns) -} - -func (t *Table) DELETE() DeleteStatement { - return newDeleteStatement(t) -} - -func (t *Table) LOCK() LockStatement { - return LOCK(t) -} - type joinType int const ( - INNER_JOIN joinType = iota - LEFT_JOIN - RIGHT_JOIN - FULL_JOIN - CROSS_JOIN + innerJoin joinType = iota + leftJoin + rightJoin + fullJoin + crossJoin ) // Join expressions are pseudo readable tables. type joinTable struct { + readableTableInterfaceImpl + lhs ReadableTable rhs ReadableTable join_type joinType @@ -201,51 +196,16 @@ func newJoinTable( join_type joinType, onCondition BoolExpression) ReadableTable { - return &joinTable{ + joinTable := &joinTable{ lhs: lhs, rhs: rhs, join_type: join_type, onCondition: onCondition, } -} -func InnerJoinOn( - lhs ReadableTable, - rhs ReadableTable, - onCondition BoolExpression) ReadableTable { + joinTable.readableTableInterfaceImpl.parent = joinTable - return newJoinTable(lhs, rhs, INNER_JOIN, onCondition) -} - -func LeftJoinOn( - lhs ReadableTable, - rhs ReadableTable, - onCondition BoolExpression) ReadableTable { - - return newJoinTable(lhs, rhs, LEFT_JOIN, onCondition) -} - -func RightJoinOn( - lhs ReadableTable, - rhs ReadableTable, - onCondition BoolExpression) ReadableTable { - - return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition) -} - -func FullJoin( - lhs ReadableTable, - rhs ReadableTable, - onCondition BoolExpression) ReadableTable { - - return newJoinTable(lhs, rhs, FULL_JOIN, onCondition) -} - -func CrossJoin( - lhs ReadableTable, - rhs ReadableTable) ReadableTable { - - return newJoinTable(lhs, rhs, CROSS_JOIN, nil) + return joinTable } func (t *joinTable) SchemaName() string { @@ -256,33 +216,13 @@ func (t *joinTable) TableName() string { return "" } -func (t *joinTable) Columns() []column { - columns := make([]column, 0) - columns = append(columns, t.lhs.Columns()...) - columns = append(columns, t.rhs.Columns()...) - - return columns -} - -func (t *joinTable) Column(name string) column { - return &baseColumn{ - name: name, - nullable: NotNullable, - } -} - func (t *joinTable) serialize(statement statementType, out *queryData, options ...serializeOption) (err error) { if t == nil { return errors.New("Join table is nil. ") } - if t.lhs == nil { - return errors.Newf("nil lhs.") - } - if t.rhs == nil { - return errors.Newf("nil rhs.") - } - if t.onCondition == nil && t.join_type != CROSS_JOIN { - return errors.Newf("nil onCondition.") + + if isNil(t.lhs) { + return errors.New("left hand side of join operation is nil table") } if err = t.lhs.serialize(statement, out); err != nil { @@ -292,24 +232,32 @@ func (t *joinTable) serialize(statement statementType, out *queryData, options . out.nextLine() switch t.join_type { - case INNER_JOIN: - out.writeString("JOIN") - case LEFT_JOIN: + case innerJoin: + out.writeString("INNER JOIN") + case leftJoin: out.writeString("LEFT JOIN") - case RIGHT_JOIN: + case rightJoin: out.writeString("RIGHT JOIN") - case FULL_JOIN: + case fullJoin: out.writeString("FULL JOIN") - case CROSS_JOIN: + case crossJoin: out.writeString("CROSS JOIN") } + if isNil(t.rhs) { + return errors.New("right hand side of join operation is nil table") + } + if err = t.rhs.serialize(statement, out); err != nil { return } + if t.onCondition == nil && t.join_type != crossJoin { + return errors.New("join condition is nil") + } + if t.onCondition != nil { - out.writeString(" ON ") + out.writeString("ON") if err = t.onCondition.serialize(statement, out); err != nil { return } @@ -317,36 +265,3 @@ func (t *joinTable) serialize(statement statementType, out *queryData, options . return nil } - -func (t *joinTable) SELECT(projections ...projection) SelectStatement { - return newSelectStatement(t, projections) -} - -func (t *joinTable) INNER_JOIN( - table ReadableTable, - onCondition BoolExpression) ReadableTable { - - return InnerJoinOn(t, table, onCondition) -} - -func (t *joinTable) LEFT_JOIN( - table ReadableTable, - onCondition BoolExpression) ReadableTable { - - return LeftJoinOn(t, table, onCondition) -} - -func (t *joinTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return FullJoin(t, table, onCondition) -} - -func (t *joinTable) CROSS_JOIN(table ReadableTable) ReadableTable { - return CrossJoin(t, table) -} - -func (t *joinTable) RIGHT_JOIN( - table ReadableTable, - onCondition BoolExpression) ReadableTable { - - return RightJoinOn(t, table, onCondition) -} diff --git a/sqlbuilder/table_test.go b/sqlbuilder/table_test.go index fb85357..220b4b2 100644 --- a/sqlbuilder/table_test.go +++ b/sqlbuilder/table_test.go @@ -1,188 +1,101 @@ -// +build disabled - package sqlbuilder import ( - "bytes" - - gc "gopkg.in/check.v1" + "testing" ) -type TableSuite struct { +func TestJoinNilInputs(t *testing.T) { + assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + "right hand side of join operation is nil table") + assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), + "join condition is nil") } -var _ = gc.Suite(&TableSuite{}) - -// NOTE: tables / columns are defined in statement_test.go - -func (s *TableSuite) TestBasicColumns(c *gc.C) { - cols := table1.Columns() - - c.Assert(len(cols), gc.Equals, 4) - c.Assert(cols[0], gc.Equals, table1Col1) - c.Assert(cols[1], gc.Equals, table1ColFloat) - c.Assert(cols[2], gc.Equals, table1Col3) - c.Assert(cols[3], gc.Equals, table1ColTime) +func TestINNER_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.colInt = table2.colInt)`) + assertClauseSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). + INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.colInt = table2.colInt) +INNER JOIN db.table3 ON (table1.colInt = table3.colInt)`) + assertClauseSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(Int(1))). + INNER_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +INNER JOIN db.table2 ON (table1.colInt = $1) +INNER JOIN db.table3 ON (table1.colInt = $2)`, int64(1), int64(2)) } -func (s *TableSuite) TestCValidLookup(c *gc.C) { - col := table1.C("col1") - - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1") +func TestLEFT_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.colInt = table2.colInt)`) + assertClauseSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.colInt = table2.colInt) +LEFT JOIN db.table3 ON (table1.colInt = table3.colInt)`) + assertClauseSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). + LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +LEFT JOIN db.table2 ON (table1.colInt = $1) +LEFT JOIN db.table3 ON (table1.colInt = $2)`, int64(1), int64(2)) } -func (s *TableSuite) TestCInvalidLookup(c *gc.C) { - col := table1.C("foo") - - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.NotNil) +func TestRIGHT_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.colInt = table2.colInt)`) + assertClauseSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.colInt = table2.colInt) +RIGHT JOIN db.table3 ON (table1.colInt = table3.colInt)`) + assertClauseSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). + RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +RIGHT JOIN db.table2 ON (table1.colInt = $1) +RIGHT JOIN db.table3 ON (table1.colInt = $2)`, int64(1), int64(2)) } -func (s *TableSuite) TestJoinNilLeftTable(c *gc.C) { - join := InnerJoinOn(nil, table2, EqL(table2Col3, 123)) - - buf := &bytes.Buffer{} - - err := join.serialize("", buf) - c.Assert(err, gc.NotNil) +func TestFULL_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.colInt = table2.colInt)`) + assertClauseSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). + FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.colInt = table2.colInt) +FULL JOIN db.table3 ON (table1.colInt = table3.colInt)`) + assertClauseSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(Int(1))). + FULL_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +FULL JOIN db.table2 ON (table1.colInt = $1) +FULL JOIN db.table3 ON (table1.colInt = $2)`, int64(1), int64(2)) } -func (s *TableSuite) TestJoinNilRightTable(c *gc.C) { - join := InnerJoinOn(table1, nil, EqL(table2Col3, 123)) - - buf := &bytes.Buffer{} - - err := join.serialize("", buf) - c.Assert(err, gc.NotNil) -} - -func (s *TableSuite) TestJoinNilOnCondition(c *gc.C) { - join := InnerJoinOn(table1, table2, nil) - - buf := &bytes.Buffer{} - - err := join.serialize("", buf) - c.Assert(err, gc.NotNil) -} - -func (s *TableSuite) TestInnerJoin(c *gc.C) { - join := table1.InnerJoinOn(table2, EQ(table1Col3, table2Col3)) - - buf := &bytes.Buffer{} - - err := join.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "db.table1 JOIN db.table2 ON table1.col3=table2.col3") -} - -func (s *TableSuite) TestLeftJoin(c *gc.C) { - join := table1.LEFT_JOIN(table2, EQ(table1Col3, table2Col3)) - - buf := &bytes.Buffer{} - - err := join.serialize("", buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "db.table1 LEFT JOIN db.table2 "+ - "ON table1.col3=table2.col3") -} - -func (s *TableSuite) TestRightJoin(c *gc.C) { - join := table1.RIGHT_JOIN(table2, EQ(table1Col3, table2Col3)) - - buf := &bytes.Buffer{} - - err := join.serialize("", buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "db.table1 RIGHT JOIN db.table2 "+ - "ON table1.col3=table2.col3") -} - -//func (s *TableSuite) TestJoinColumns(c *gc.C) { -// join := table1.RIGHT_JOIN(table2, EQ(table1Col3, table2Col3)) -// -// cols := join.Columns() -// c.Assert(len(cols), gc.Equals, 6) -// c.Assert(cols[0], gc.Equals, table1Col1) -// c.Assert(cols[1], gc.Equals, table1ColFloat) -// c.Assert(cols[2], gc.Equals, table1Col3) -// c.Assert(cols[3], gc.Equals, table1ColTime) -// c.Assert(cols[4], gc.Equals, table2Col3) -// c.Assert(cols[5], gc.Equals, table2Col4) -//} - -func (s *TableSuite) TestNestedInnerJoin(c *gc.C) { - join1 := table1.InnerJoinOn(table2, EQ(table1Col3, table2Col3)) - join2 := join1.InnerJoinOn(table3, EQ(table1Col1, table3Col1)) - - buf := &bytes.Buffer{} - - err := join2.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "db.table1 "+ - "JOIN db.table2 ON table1.col3=table2.col3 "+ - "JOIN db.table3 ON table1.col1=table3.col1") -} - -func (s *TableSuite) TestNestedLeftJoin(c *gc.C) { - join1 := table1.InnerJoinOn(table2, EQ(table1Col3, table2Col3)) - join2 := join1.LeftJoinOn(table3, EQ(table1Col1, table3Col1)) - - buf := &bytes.Buffer{} - - err := join2.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "db.table1 "+ - "JOIN db.table2 ON table1.col3=table2.col3 "+ - "LEFT JOIN db.table3 ON table1.col1=table3.col1") -} - -func (s *TableSuite) TestNestedRightJoin(c *gc.C) { - join1 := table1.InnerJoinOn(table2, EQ(table1Col3, table2Col3)) - join2 := join1.RightJoinOn(table3, EQ(table1Col1, table3Col1)) - - buf := &bytes.Buffer{} - - err := join2.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "db.table1 "+ - "JOIN db.table2 ON table1.col3=table2.col3 "+ - "RIGHT JOIN db.table3 ON table1.col1=table3.col1") +func TestCROSS_JOIN(t *testing.T) { + assertClauseSerialize(t, table1. + CROSS_JOIN(table2), + `db.table1 +CROSS JOIN db.table2`) + assertClauseSerialize(t, table1. + CROSS_JOIN(table2). + CROSS_JOIN(table3), + `db.table1 +CROSS JOIN db.table2 +CROSS JOIN db.table3`) } diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go index 403af74..75ed415 100644 --- a/sqlbuilder/test_utils.go +++ b/sqlbuilder/test_utils.go @@ -1,16 +1,17 @@ package sqlbuilder import ( + "fmt" "gotest.tools/assert" "testing" ) -var table1Col1 = NewIntegerColumn("col1", Nullable) -var table1ColInt = NewIntegerColumn("colInt", Nullable) -var table1ColFloat = NewFloatColumn("colFloat", Nullable) -var table1Col3 = NewIntegerColumn("col3", Nullable) -var table1ColTime = NewTimeColumn("colTime", Nullable) -var table1ColBool = NewBoolColumn("colBool", Nullable) +var table1Col1 = NewIntegerColumn("col1", true) +var table1ColInt = NewIntegerColumn("colInt", true) +var table1ColFloat = NewFloatColumn("colFloat", true) +var table1Col3 = NewIntegerColumn("col3", true) +var table1ColTime = NewTimeColumn("colTime", true) +var table1ColBool = NewBoolColumn("colBool", true) var table1 = NewTable( "db", @@ -22,13 +23,13 @@ var table1 = NewTable( table1ColTime, table1ColBool) -var table2Col3 = NewIntegerColumn("col3", Nullable) -var table2Col4 = NewIntegerColumn("col4", Nullable) -var table2ColInt = NewIntegerColumn("colInt", Nullable) -var table2ColFloat = NewFloatColumn("colFloat", Nullable) -var table2ColStr = NewStringColumn("colStr", Nullable) -var table2ColBool = NewBoolColumn("colBool", Nullable) -var table2ColTime = NewTimeColumn("colTime", Nullable) +var table2Col3 = NewIntegerColumn("col3", true) +var table2Col4 = NewIntegerColumn("col4", true) +var table2ColInt = NewIntegerColumn("colInt", true) +var table2ColFloat = NewFloatColumn("colFloat", true) +var table2ColStr = NewStringColumn("colStr", true) +var table2ColBool = NewBoolColumn("colBool", true) +var table2ColTime = NewTimeColumn("colTime", true) var table2 = NewTable( "db", @@ -41,17 +42,19 @@ var table2 = NewTable( table2ColBool, table2ColTime) -var table3Col1 = NewIntegerColumn("col1", Nullable) -var table3StrCol = NewStringColumn("col2", Nullable) +var table3Col1 = NewIntegerColumn("col1", true) +var table3ColInt = NewIntegerColumn("colInt", true) +var table3StrCol = NewStringColumn("col2", true) var table3 = NewTable( "db", "table3", table3Col1, + table3ColInt, table3StrCol) -func assertExpressionSerialize(t *testing.T, expression Expression, query string, args ...interface{}) { +func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) { out := queryData{} - err := expression.serialize(select_statement, &out) + err := clause.serialize(select_statement, &out) assert.NilError(t, err) @@ -59,6 +62,15 @@ func assertExpressionSerialize(t *testing.T, expression Expression, query string assert.DeepEqual(t, out.args, args) } +func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { + out := queryData{} + err := clause.serialize(select_statement, &out) + + fmt.Println(err) + assert.Assert(t, err != nil) + assert.Equal(t, err.Error(), errString) +} + func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) { out := queryData{} err := projection.serializeForProjection(select_statement, &out) diff --git a/sqlbuilder/time_expression_test.go b/sqlbuilder/time_expression_test.go index 154de6d..5b35c55 100644 --- a/sqlbuilder/time_expression_test.go +++ b/sqlbuilder/time_expression_test.go @@ -5,31 +5,31 @@ import ( ) func TestTimeExpressionEQ(t *testing.T) { - assertExpressionSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.colTime = table2.colTime)") - assertExpressionSerialize(t, table1ColTime.EQ(Time(10, 20, 0, 0)), "(table1.colTime = $1::time without time zone)", "10:20:00.000") + assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.colTime = table2.colTime)") + assertClauseSerialize(t, table1ColTime.EQ(Time(10, 20, 0, 0)), "(table1.colTime = $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionNOT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.colTime != table2.colTime)") - assertExpressionSerialize(t, table1ColTime.NOT_EQ(Time(10, 20, 0, 0)), "(table1.colTime != $1::time without time zone)", "10:20:00.000") + assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.colTime != table2.colTime)") + assertClauseSerialize(t, table1ColTime.NOT_EQ(Time(10, 20, 0, 0)), "(table1.colTime != $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionLT(t *testing.T) { - assertExpressionSerialize(t, table1ColTime.LT(table2ColTime), "(table1.colTime < table2.colTime)") - assertExpressionSerialize(t, table1ColTime.LT(Time(10, 20, 0, 0)), "(table1.colTime < $1::time without time zone)", "10:20:00.000") + assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.colTime < table2.colTime)") + assertClauseSerialize(t, table1ColTime.LT(Time(10, 20, 0, 0)), "(table1.colTime < $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionLT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.colTime <= table2.colTime)") - assertExpressionSerialize(t, table1ColTime.LT_EQ(Time(10, 20, 0, 0)), "(table1.colTime <= $1::time without time zone)", "10:20:00.000") + assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.colTime <= table2.colTime)") + assertClauseSerialize(t, table1ColTime.LT_EQ(Time(10, 20, 0, 0)), "(table1.colTime <= $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionGT(t *testing.T) { - assertExpressionSerialize(t, table1ColTime.GT(table2ColTime), "(table1.colTime > table2.colTime)") - assertExpressionSerialize(t, table1ColTime.GT(Time(10, 20, 0, 0)), "(table1.colTime > $1::time without time zone)", "10:20:00.000") + assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.colTime > table2.colTime)") + assertClauseSerialize(t, table1ColTime.GT(Time(10, 20, 0, 0)), "(table1.colTime > $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionGT_EQ(t *testing.T) { - assertExpressionSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.colTime >= table2.colTime)") - assertExpressionSerialize(t, table1ColTime.GT_EQ(Time(10, 20, 0, 0)), "(table1.colTime >= $1::time without time zone)", "10:20:00.000") + assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.colTime >= table2.colTime)") + assertClauseSerialize(t, table1ColTime.GT_EQ(Time(10, 20, 0, 0)), "(table1.colTime >= $1::time without time zone)", "10:20:00.000") } diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index abaec44..2b46de4 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -2,7 +2,7 @@ package sqlbuilder import ( "database/sql" - "github.com/dropbox/godropbox/errors" + "errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" ) @@ -14,7 +14,7 @@ type UpdateStatement interface { RETURNING(projections ...projection) UpdateStatement } -func newUpdateStatement(table writableTable, columns []column) UpdateStatement { +func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement { return &updateStatementImpl{ table: table, columns: columns, @@ -22,8 +22,8 @@ func newUpdateStatement(table writableTable, columns []column) UpdateStatement { } type updateStatementImpl struct { - table writableTable - columns []column + table WritableTable + columns []Column updateValues []clause where BoolExpression returning []projection diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 5de6de5..23a2fe6 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" + "reflect" ) func serializeOrderByClauseList(statement statementType, orderByClauses []OrderByClause, out *queryData) error { @@ -97,7 +98,7 @@ func serializeProjectionList(statement statementType, projections []projection, return nil } -func serializeColumnList(statement statementType, columns []column, out *queryData) error { +func serializeColumnList(statement statementType, columns []Column, out *queryData) error { for i, col := range columns { if i > 0 { out.writeByte(',') @@ -113,6 +114,10 @@ func serializeColumnList(statement statementType, columns []column, out *queryDa return nil } +func isNil(v interface{}) bool { + return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) +} + //func stringExpressionListToExpressionList(stringExpressions []StringExpression) []Expression{ // var ret []Expression // diff --git a/tests/select_test.go b/tests/select_test.go index b1fed8f..a6534ea 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -63,7 +63,7 @@ SELECT payment.payment_id AS "payment.payment_id", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.payment - JOIN dvds.customer ON (payment.customer_id = customer.customer_id) + INNER JOIN dvds.customer ON (payment.customer_id = customer.customer_id) ORDER BY payment.payment_id ASC LIMIT 30; ` @@ -197,11 +197,11 @@ SELECT film_actor.actor_id AS "film_actor.actor_id", rental.staff_id AS "rental.staff_id", rental.last_update AS "rental.last_update" FROM dvds.film_actor - JOIN dvds.actor ON (film_actor.actor_id = actor.actor_id) - JOIN dvds.film ON (film_actor.film_id = film.film_id) - JOIN dvds.language ON (film.language_id = language.language_id) - JOIN dvds.inventory ON (inventory.film_id = film.film_id) - JOIN dvds.rental ON (rental.inventory_id = inventory.inventory_id) + INNER JOIN dvds.actor ON (film_actor.actor_id = actor.actor_id) + INNER JOIN dvds.film ON (film_actor.film_id = film.film_id) + INNER JOIN dvds.language ON (film.language_id = language.language_id) + INNER JOIN dvds.inventory ON (inventory.film_id = film.film_id) + INNER JOIN dvds.rental ON (rental.inventory_id = inventory.inventory_id) ORDER BY film.film_id ASC LIMIT 50; ` @@ -272,7 +272,7 @@ SELECT language.language_id AS "language.language_id", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film - JOIN dvds.language ON (film.language_id = language.language_id) + INNER JOIN dvds.language ON (film.language_id = language.language_id) WHERE film.rating = 'NC-17' LIMIT 15; ` @@ -569,7 +569,7 @@ SELECT f1.film_id AS "f1.film_id", f2.special_features AS "f2.special_features", f2.fulltext AS "f2.fulltext" FROM dvds.film AS f1 - JOIN dvds.film AS f2 ON ((f1.film_id < f2.film_id) AND (f1.length = f2.length)) + INNER JOIN dvds.film AS f2 ON ((f1.film_id < f2.film_id) AND (f1.length = f2.length)) ORDER BY f1.film_id ASC; ` f1 := Film.AS("f1") @@ -606,7 +606,7 @@ SELECT f1.title AS "thesame_length_films.title1", f2.title AS "thesame_length_films.title2", f1.length AS "thesame_length_films.length" FROM dvds.film AS f1 - JOIN dvds.film AS f2 ON ((f1.film_id != f2.film_id) AND (f1.length = f2.length)) + INNER JOIN dvds.film AS f2 ON ((f1.film_id != f2.film_id) AND (f1.length = f2.length)) ORDER BY f1.length ASC, f1.title ASC, f2.title ASC LIMIT 1000; ` @@ -682,8 +682,8 @@ LIMIT 1000; // manager := Staff.AS("manager") // // query := Staff. -// INNER_JOIN(Address, Staff.AddressID.EQ(Address.AddressID)). -// INNER_JOIN(manager, Staff.StaffID.EQ(manager.StaffID)). +// innerJoin(Address, Staff.AddressID.EQ(Address.AddressID)). +// innerJoin(manager, Staff.StaffID.EQ(manager.StaffID)). // SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName). // DISTINCT() // @@ -717,7 +717,7 @@ func TestSubQuery(t *testing.T) { //avrgCustomer := NumExp(Customer.SELECT(Customer.LastName).LIMIT(1)) // //Customer. - // INNER_JOIN(selectStmtTable, Customer.LastName.EQ(selectStmtTable.RefStringColumn(Actor.FirstName))). + // innerJoin(selectStmtTable, Customer.LastName.EQ(selectStmtTable.RefStringColumn(Actor.FirstName))). // SELECT(Customer.AllColumns, selectStmtTable.RefIntColumnName("first_name")). // WHERE(Actor.LastName.Neq(avrgCustomer)) @@ -732,8 +732,8 @@ SELECT actor.actor_id AS "actor.actor_id", films."film.title" AS "film.title", films."film.rating" AS "film.rating" FROM dvds.actor - JOIN dvds.film_actor ON (actor.actor_id = film_actor.film_id) - JOIN ( + INNER JOIN dvds.film_actor ON (actor.actor_id = film_actor.film_id) + INNER JOIN ( SELECT film.film_id AS "film.film_id", film.title AS "film.title", film.rating AS "film.rating" @@ -918,7 +918,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.active AS "customer.active", customer_payment_sum.amount_sum AS "customer_with_amounts.amount_sum" FROM dvds.customer - JOIN ( + INNER JOIN ( SELECT payment.customer_id AS "payment.customer_id", SUM(payment.amount) AS "amount_sum" FROM dvds.payment