diff --git a/generator/templates.go b/generator/templates.go index 4c43958..d35b539 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -12,7 +12,7 @@ type {{.ToGoStructName}} struct { {{.ToGoFieldName}} sqlbuilder.NonAliasColumn {{- end}} - All []sqlbuilder.Projection + AllColumns sqlbuilder.ColumnList } var {{.ToGoVarName}} = &{{.ToGoStructName}}{ @@ -23,7 +23,7 @@ var {{.ToGoVarName}} = &{{.ToGoStructName}}{ {{.ToGoFieldName}}: {{.ToGoVarName}}, {{- end}} - All: []sqlbuilder.Projection{ {{.ToGoColumnFieldList ", "}} }, + AllColumns: sqlbuilder.ColumnList{ {{.ToGoColumnFieldList ", "}} }, } var ( diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go index c9d05ea..82a6452 100644 --- a/sqlbuilder/types.go +++ b/sqlbuilder/types.go @@ -32,6 +32,33 @@ type Projection interface { SerializeSqlForColumnList(out *bytes.Buffer) error } +type ColumnList []NonAliasColumn + +func (cl ColumnList) SerializeSql(out *bytes.Buffer) error { + for i, column := range cl { + column.SerializeSql(out) + + if i != len(cl)-1 { + out.WriteString(", ") + } + } + return nil +} + +func (cl ColumnList) isProjectionType() { +} + +func (cl ColumnList) SerializeSqlForColumnList(out *bytes.Buffer) error { + for i, column := range cl { + column.SerializeSqlForColumnList(out) + + if i != len(cl)-1 { + out.WriteString(", ") + } + } + return nil +} + // // Boiler plates ... // diff --git a/tests/generator_test.go b/tests/generator_test.go index 8ec4646..c8d9d08 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -58,7 +58,7 @@ func TestGenerateModel(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) { actor := model.Actor{} - err := Actor.Select(Actor.All...).Execute(db, &actor) + err := Actor.Select(Actor.AllColumns).Execute(db, &actor) assert.NilError(t, err) @@ -75,13 +75,12 @@ func TestSelect_ScanToStruct(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) { customers := []model.Customer{} - query := Customer.Select(Customer.All...) + query := Customer.Select(Customer.AllColumns) queryStr, err := query.String() assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id",customer.store_id AS "customer.store_id",customer.first_name AS "customer.first_name",customer.last_name AS "customer.last_name",customer.email AS "customer.email",customer.address_id AS "customer.address_id",customer.activebool AS "customer.activebool",customer.create_date AS "customer.create_date",customer.last_update AS "customer.last_update",customer.active AS "customer.active" FROM dvds.customer`) - + assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer`) err = query.Execute(db, &customers) assert.NilError(t, err) @@ -134,19 +133,18 @@ func TestSelect_ScanToSlice(t *testing.T) { func TestJoinQueryStruct(t *testing.T) { - //filmActor := model.FilmActor{} - allFilmActorColumns := append(append(append(FilmActor.All, Film.All...), Language.All...), Actor.All...) query := FilmActor. InnerJoinOn(Actor, sqlbuilder.Eq(FilmActor.ActorID, Actor.ActorID)). InnerJoinOn(Film, sqlbuilder.Eq(FilmActor.FilmID, Film.FilmID)). InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). - Select(allFilmActorColumns...). + Select(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). Where(sqlbuilder.And(sqlbuilder.Gte(FilmActor.ActorID, sqlbuilder.Literal(1)), sqlbuilder.Lte(FilmActor.ActorID, sqlbuilder.Literal(2)))) queryStr, err := query.String() assert.NilError(t, err) + assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) - fmt.Println(queryStr) + //fmt.Println(queryStr) filmActor := []model.FilmActor{} @@ -168,14 +166,17 @@ func TestJoinQuerySlice(t *testing.T) { limit := 15 query := Film.InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). - Select(append(Language.All, Film.All...)...). + Select(Language.AllColumns, Film.AllColumns). Limit(15) - queryStr, _ := query.String() + queryStr, err := query.String() - fmt.Println(queryStr) + assert.NilError(t, err) + assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id LIMIT 15`) - err := query.Execute(db, &filmsPerLanguage) + //fmt.Println(queryStr) + + err = query.Execute(db, &filmsPerLanguage) assert.NilError(t, err) @@ -205,7 +206,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { limit := int64(3) query := Film.InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). - Select(append(Language.All, Film.All...)...). + Select(Language.AllColumns, Film.AllColumns). Limit(limit) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{}