diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index afe81e9..0c2039a 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -4,6 +4,7 @@ package sqlbuilder import ( "bytes" + "github.com/dropbox/godropbox/database/sqltypes" "regexp" "github.com/dropbox/godropbox/errors" @@ -14,6 +15,7 @@ import ( // Representation of a table for query generation type Column interface { isProjectionInterface + isExpressionInterface Name() string // Serialization for use in column lists @@ -24,6 +26,14 @@ type Column interface { // Internal function for tracking table that a column belongs to // for the purpose of serialization setTableName(table string) error + + Eq(rhs Expression) BoolExpression + + Gte(rhs Expression) BoolExpression + GteLiteral(rhs interface{}) BoolExpression + + Lte(rhs Expression) BoolExpression + LteLiteral(rhs interface{}) BoolExpression } type NullableColumn bool @@ -37,7 +47,6 @@ const ( type NonAliasColumn interface { Column isOrderByClauseInterface - isExpressionInterface } type Collation string @@ -94,6 +103,26 @@ func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { return nil } +func (c *baseColumn) Eq(rhs Expression) BoolExpression { + return Eq(c, rhs) +} + +func (c *baseColumn) Gte(rhs Expression) BoolExpression { + return Gte(c, rhs) +} + +func (c *baseColumn) GteLiteral(rhs interface{}) BoolExpression { + return Gte(c, Literal(rhs)) +} + +func (c *baseColumn) Lte(rhs Expression) BoolExpression { + return Lte(c, rhs) +} + +func (c *baseColumn) LteLiteral(literal interface{}) BoolExpression { + return Lte(c, Literal(literal)) +} + type bytesColumn struct { baseColumn isExpression @@ -305,3 +334,27 @@ func (c *deferredLookupColumn) setTableName(table string) error { "Lookup column '%s' should never have setTableName called on it", c.colName) } + +func (c *deferredLookupColumn) Eq(rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(c, rhs, []byte(" IS ")) + } + return newBoolExpression(c, rhs, []byte(" = ")) +} + +func (c *deferredLookupColumn) Gte(rhs Expression) BoolExpression { + return Gte(c, rhs) +} + +func (c *deferredLookupColumn) GteLiteral(rhs interface{}) BoolExpression { + return Gte(c, Literal(rhs)) +} + +func (c *deferredLookupColumn) Lte(rhs Expression) BoolExpression { + return Lte(c, rhs) +} + +func (c *deferredLookupColumn) LteLiteral(literal interface{}) BoolExpression { + return Lte(c, Literal(literal)) +} diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index ff1dccc..2a58bfd 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -26,6 +26,8 @@ type ReadableTable interface { // Creates a inner join table expression using onCondition. InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable + // Creates a left join table expression using onCondition. LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable @@ -170,6 +172,14 @@ func (t *Table) InnerJoinOn( return InnerJoinOn(t, table, onCondition) } +func (t *Table) InnerJoinUsing( + table ReadableTable, + col1 Column, + col2 Column) ReadableTable { + + return InnerJoinOn(t, table, col1.Eq(col2)) +} + // Creates a left join table expression using onCondition. func (t *Table) LeftJoinOn( table ReadableTable, @@ -308,6 +318,14 @@ func (t *joinTable) InnerJoinOn( return InnerJoinOn(t, table, onCondition) } +func (t *joinTable) InnerJoinUsing( + table ReadableTable, + col1 Column, + col2 Column) ReadableTable { + + return InnerJoinOn(t, table, col1.Eq(col2)) +} + func (t *joinTable) LeftJoinOn( table ReadableTable, onCondition BoolExpression) ReadableTable { diff --git a/tests/generator_test.go b/tests/generator_test.go index c8d9d08..d877b40 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -134,11 +134,11 @@ func TestSelect_ScanToSlice(t *testing.T) { func TestJoinQueryStruct(t *testing.T) { query := FilmActor. - InnerJoinOn(Actor, sqlbuilder.Eq(FilmActor.ActorID, Actor.ActorID)). - InnerJoinOn(Film, sqlbuilder.Eq(FilmActor.FilmID, Film.FilmID)). - InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). + InnerJoinUsing(Actor, FilmActor.ActorID, Actor.ActorID). + InnerJoinUsing(Film, FilmActor.FilmID, Film.FilmID). + InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). Select(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). - Where(sqlbuilder.And(sqlbuilder.Gte(FilmActor.ActorID, sqlbuilder.Literal(1)), sqlbuilder.Lte(FilmActor.ActorID, sqlbuilder.Literal(2)))) + Where(sqlbuilder.And(FilmActor.ActorID.GteLiteral(1), FilmActor.ActorID.LteLiteral(2))) queryStr, err := query.String() assert.NilError(t, err) @@ -165,7 +165,7 @@ func TestJoinQuerySlice(t *testing.T) { filmsPerLanguage := []FilmsPerLanguage{} limit := 15 - query := Film.InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). + query := Film.InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). Select(Language.AllColumns, Film.AllColumns). Limit(15) @@ -205,7 +205,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { limit := int64(3) - query := Film.InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). + query := Film.InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). Select(Language.AllColumns, Film.AllColumns). Limit(limit)