Add self join support.

This commit is contained in:
sub0Zero 2019-03-16 20:41:06 +01:00 committed by zer0sub
parent 20c6f39665
commit 1cb997fc54
5 changed files with 89 additions and 15 deletions

View file

@ -23,7 +23,7 @@ func (c ColumnInfo) IsUnique() bool {
}
func (c ColumnInfo) ToGoVarName() string {
return snaker.SnakeToCamelLower(c.TableInfo.Name) + snaker.SnakeToCamel(c.Name) + "Column"
return snaker.SnakeToCamel(c.Name) + "Column"
}
func (c ColumnInfo) ToGoType() string {

View file

@ -2,7 +2,9 @@ package generator
var SqlBuilderTableTemplate = `package table
import "github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
import (
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
)
type {{.ToGoStructName}} struct {
sqlbuilder.Table
@ -15,22 +17,36 @@ type {{.ToGoStructName}} struct {
AllColumns sqlbuilder.ColumnList
}
var {{.ToGoVarName}} = &{{.ToGoStructName}}{
Table: *sqlbuilder.NewTable("{{.DatabaseInfo.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}),
//Columns
var {{.ToGoVarName}} = new{{.ToGoStructName}}()
func new{{.ToGoStructName}}() *{{.ToGoStructName}} {
var (
{{- range .Columns}}
{{.ToGoVarName}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}})
{{- end}}
)
return &{{.ToGoStructName}}{
Table: *sqlbuilder.NewTable("{{.DatabaseInfo.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}),
//Columns
{{- range .Columns}}
{{.ToGoFieldName}}: {{.ToGoVarName}},
{{.ToGoFieldName}}: {{.ToGoVarName}},
{{- end}}
AllColumns: sqlbuilder.ColumnList{ {{.ToGoColumnFieldList ", "}} },
AllColumns: sqlbuilder.ColumnList{ {{.ToGoColumnFieldList ", "}} },
}
}
func (a *{{.ToGoStructName}}) As(alias string) *{{.ToGoStructName}} {
aliasTable := new{{.ToGoStructName}}()
aliasTable.Table.SetAlias(alias)
return aliasTable
}
var (
{{- range .Columns}}
{{.ToGoVarName}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}})
{{- end}}
)
`
var DataModelTemplate = `package model

View file

@ -27,6 +27,7 @@ type Column interface {
setTableName(table string) error
Eq(rhs Expression) BoolExpression
Neq(rhs Expression) BoolExpression
Gte(rhs Expression) BoolExpression
GteLiteral(rhs interface{}) BoolExpression
@ -109,6 +110,10 @@ func (c *baseColumn) Eq(rhs Expression) BoolExpression {
return Eq(c, rhs)
}
func (c *baseColumn) Neq(rhs Expression) BoolExpression {
return Neq(c, rhs)
}
func (c *baseColumn) Gte(rhs Expression) BoolExpression {
return Gte(c, rhs)
}
@ -198,7 +203,7 @@ type IntegerColumn struct {
// Representation of any integer column
// This function will panic if name is not valid
func IntColumn(name string, nullable NullableColumn) NonAliasColumn {
func IntColumn(name string, nullable NullableColumn) *IntegerColumn {
if !validIdentifierName(name) {
panic("Invalid column name in int column")
}

View file

@ -5,7 +5,6 @@ package sqlbuilder
import (
"bytes"
"fmt"
"github.com/dropbox/godropbox/errors"
)
@ -85,6 +84,7 @@ func NewTable(schemaName, name string, columns ...NonAliasColumn) *Table {
type Table struct {
schemaName string
name string
alias string
columns []NonAliasColumn
columnLookup map[string]NonAliasColumn
// If not empty, the name of the index to force
@ -119,6 +119,17 @@ func (t *Table) Projections() []Projection {
return result
}
func (t *Table) SetAlias(alias string) {
t.alias = alias
for _, c := range t.columns {
err := c.setTableName(alias)
if err != nil {
panic(err)
}
}
}
// Returns the table's name in the database
func (t *Table) Name() string {
return t.name
@ -151,6 +162,11 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error {
_, _ = out.WriteString(".")
_, _ = out.WriteString(t.Name())
if len(t.alias) > 0 {
out.WriteString(" AS ")
out.WriteString(t.alias)
}
if t.forcedIndex != "" {
if !validIdentifierName(t.forcedIndex) {
return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex)

View file

@ -3,6 +3,7 @@ package tests
import (
"database/sql"
"fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0Zero/go-sqlbuilder/generator"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
@ -290,6 +291,42 @@ func TestSelectFullCrossJoin(t *testing.T) {
assert.NilError(t, err)
}
func TestSelectSelfJoin(t *testing.T) {
f1 := Film.As("f1")
//spew.Dump(f1)
f2 := Film.As("f2")
query := f1.
InnerJoinOn(f2, f1.FilmID.Neq(f2.FilmID).And(f1.Length.Eq(f2.Length))).
Select(f1.AllColumns, f2.AllColumns).
OrderBy(f1.FilmID)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
type F1 model.Film
type F2 model.Film
theSameLengthFilms := []struct {
F1 F1
F2 F2
}{}
err = query.Execute(db, &theSameLengthFilms)
assert.NilError(t, err)
spew.Dump(theSameLengthFilms[0])
assert.Equal(t, len(theSameLengthFilms), 6972)
}
func int32Ptr(i int32) *int32 {
return &i
}