From b88519bfd457f7feace85c7b6b7a84c8e89b67ea Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 20 Sep 2019 12:53:52 +0200 Subject: [PATCH] [Feature] Add support for database views. [Feature] Add support to manually set primary keys for destination structure fields. --- execution/execution.go | 36 +++++- .../internal/metadata/schema_meta_data.go | 37 ++++-- .../internal/metadata/table_meta_data.go | 4 +- generator/internal/template/generate.go | 71 +++++++---- generator/internal/template/templates.go | 2 +- generator/mysql/mysql_generator.go | 4 +- generator/mysql/query_set.go | 4 +- generator/postgres/postgres_generator.go | 4 +- generator/postgres/query_set.go | 2 +- internal/testutils/test_utils.go | 26 ++++ internal/utils/utils.go | 1 + tests/mysql/alltypes_test.go | 18 +++ tests/mysql/generator_test.go | 111 ++++++++++++------ tests/mysql/select_test.go | 74 ++++++++++-- tests/postgres/alltypes_test.go | 16 ++- tests/postgres/generator_test.go | 106 +++++++++++------ tests/postgres/select_test.go | 72 +++++++++++- tests/testdata | 2 +- 18 files changed, 462 insertions(+), 128 deletions(-) diff --git a/execution/execution.go b/execution/execution.go index 363b0f4..8a8fc24 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -771,7 +771,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { ret.subTypes = append(ret.subTypes, subType) } - } else if isPrimaryKey(field) { + } else if isPrimaryKey(field, parentField) { index := s.typeToColumnIndex(newTypeName, fieldName) if index < 0 { @@ -837,13 +837,45 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value { return newElem } -func isPrimaryKey(field reflect.StructField) bool { +func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool { + + if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite { + return isPrimaryKey + } sqlTag := field.Tag.Get("sql") return sqlTag == "primary_key" } +func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) { + if parentField == nil { + return + } + + sqlTag := parentField.Tag.Get("sql") + + if !strings.HasPrefix(sqlTag, "primary_key") { + return + } + + parts := strings.Split(sqlTag, "=") + + if len(parts) < 2 { + return + } + + primaryKeyColumns := strings.Split(parts[1], ",") + + for _, primaryKeyCol := range primaryKeyColumns { + if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) { + return true, true + } + } + + return true, false +} + func indirectType(reflectType reflect.Type) reflect.Type { if reflectType.Kind() != reflect.Ptr { return reflectType diff --git a/generator/internal/metadata/schema_meta_data.go b/generator/internal/metadata/schema_meta_data.go index 0a05d6a..b7d3b7c 100644 --- a/generator/internal/metadata/schema_meta_data.go +++ b/generator/internal/metadata/schema_meta_data.go @@ -7,33 +7,50 @@ import ( // SchemaMetaData struct type SchemaMetaData struct { - TableInfos []MetaData - EnumInfos []MetaData + TablesMetaData []MetaData + ViewsMetaData []MetaData + EnumsMetaData []MetaData } -// GetSchemaInfo returns schema information from db connection. -func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) { +func (s SchemaMetaData) IsEmpty() bool { + return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0 +} - schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName) +const ( + baseTable = "BASE TABLE" + view = "VIEW" +) + +// GetSchemaMetaData returns schema information from db connection. +func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) { + + schemaInfo.TablesMetaData, err = getTablesMetaData(db, querySet, schemaName, baseTable) if err != nil { return } - schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName) + schemaInfo.ViewsMetaData, err = getTablesMetaData(db, querySet, schemaName, view) if err != nil { return } - fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)") + schemaInfo.EnumsMetaData, err = querySet.GetEnumsMetaData(db, schemaName) + + if err != nil { + return + } + + fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),", + len(schemaInfo.EnumsMetaData), "enum(s)") return } -func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) { +func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) ([]MetaData, error) { - rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName) + rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType) if err != nil { return nil, err @@ -49,7 +66,7 @@ func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]M return nil, err } - tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName) + tableInfo, err := GetTableMetaData(db, querySet, schemaName, tableName) if err != nil { return nil, err diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go index c2f4b23..eea0604 100644 --- a/generator/internal/metadata/table_meta_data.go +++ b/generator/internal/metadata/table_meta_data.go @@ -67,8 +67,8 @@ func (t TableMetaData) GoStructName() string { return utils.ToGoIdentifier(t.name) + "Table" } -// GetTableInfo returns table info metadata -func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) { +// GetTableMetaData returns table info metadata +func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) { tableInfo.SchemaName = schemaName tableInfo.name = tableName diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index a16c133..c545335 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -12,8 +12,8 @@ import ( ) // GenerateFiles generates Go files from tables and enums metadata -func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error { - if len(tables) == 0 && len(enums) == 0 { +func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) error { + if schemaInfo.IsEmpty() { return nil } @@ -25,43 +25,64 @@ func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect je return err } - fmt.Println("Generating table sql builder files...") - err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect) + err = generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) if err != nil { return err } - fmt.Println("Generating table model files...") - err = generate(destDir, "model", tableModelTemplate, tables, dialect) + err = generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) if err != nil { return err } - if len(enums) > 0 { - fmt.Println("Generating enum sql builder files...") - err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect) + err = generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) - if err != nil { - return err - } + if err != nil { + return err + } - fmt.Println("Generating enum model files...") - err = generate(destDir, "model", enumModelTemplate, enums, dialect) + err = generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect) - if err != nil { - return err - } + if err != nil { + return err + } + + err = generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect) + + if err != nil { + return err + } + + err = generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect) + + if err != nil { + return err } fmt.Println("Done") return nil - } -func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error { +func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) error { + if len(metaData) == 0 { + return nil + } + fmt.Printf("Generating %s sql builder files...\n", fileTypes) + return generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect) +} + +func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) error { + if len(metaData) == 0 { + return nil + } + fmt.Printf("Generating %s model files...\n", fileTypes) + return generateGoFiles(destDir, "model", modelTemplate, metaData, dialect) +} + +func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error { modelDirPath := filepath.Join(dirPath, packageName) err := utils.EnsureDirPath(modelDirPath) @@ -77,7 +98,7 @@ func generate(dirPath, packageName string, template string, metaDataList []metad } for _, metaData := range metaDataList { - text, err := GenerateTemplate(template, metaData, dialect) + text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName}) if err != nil { return err @@ -94,7 +115,7 @@ func generate(dirPath, packageName string, template string, metaDataList []metad } // GenerateTemplate generates template with template text and template data. -func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) { +func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) { t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ "ToGoIdentifier": utils.ToGoIdentifier, @@ -102,7 +123,13 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect1 je return time.Now().Format(time.RFC850) }, "dialect": func() jet.Dialect { - return dialect1 + return dialect + }, + "param": func(name string) interface{} { + if len(params) > 0 { + return params[0][name] + } + return "" }, }).Parse(templateText) diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index c7c064f..0a2a9c7 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -18,7 +18,7 @@ var tableSQLBuilderTemplate = ` {{- end}} {{- end}} -package table +package {{param "package"}} import ( "github.com/go-jet/jet/{{dialect.PackageName}}" diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index c8b01ee..4a4b4ca 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -31,7 +31,7 @@ func Generate(destDir string, dbConn DBConnection) error { fmt.Println("Retrieving database information...") // No schemas in MySQL - dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &mySqlQuerySet{}) + dbInfo, err := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) if err != nil { return err @@ -39,7 +39,7 @@ func Generate(destDir string, dbConn DBConnection) error { genPath := path.Join(destDir, dbConn.DBName) - err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect) + err = template.GenerateFiles(genPath, dbInfo, mysql.Dialect) if err != nil { return err diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 20a01ac..3d146d4 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -13,7 +13,7 @@ func (m *mySqlQuerySet) ListOfTablesQuery() string { return ` SELECT table_name FROM INFORMATION_SCHEMA.tables -WHERE table_schema = ? and table_type = 'BASE TABLE'; +WHERE table_schema = ? and table_type = ?; ` } @@ -46,7 +46,7 @@ func (m *mySqlQuerySet) ListOfEnumsQuery() string { SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) FROM information_schema.columns as c INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) -WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE'; +WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; ` } diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index ce2ea34..d46d6f0 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -35,7 +35,7 @@ func Generate(destDir string, dbConn DBConnection) error { } fmt.Println("Retrieving schema information...") - schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &postgresQuerySet{}) + schemaInfo, err := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) if err != nil { return err @@ -43,7 +43,7 @@ func Generate(destDir string, dbConn DBConnection) error { genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) - err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect) + err = template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) if err != nil { return err diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index a4f1fdd..f1c7457 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -12,7 +12,7 @@ func (p *postgresQuerySet) ListOfTablesQuery() string { return ` SELECT table_name FROM information_schema.tables -where table_schema = $1 and table_type = 'BASE TABLE'; +where table_schema = $1 and table_type = $2; ` } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index d5dbbd6..c2779fe 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -159,3 +159,29 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest stmt.Query(db, dest) } + +func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { + enumFileData, err := ioutil.ReadFile(filePath) + + assert.NilError(t, err) + + beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) + + //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") + + assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) +} + +func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { + assert.Equal(t, len(fileInfos), len(fileNames)) + + fileNamesMap := map[string]bool{} + + for _, fileInfo := range fileInfos { + fileNamesMap[fileInfo.Name()] = true + } + + for _, fileName := range fileNames { + assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") + } +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 9694cbb..9e623df 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -104,6 +104,7 @@ func DirExists(path string) (bool, error) { func replaceInvalidChars(str string) string { str = strings.Replace(str, " ", "_", -1) str = strings.Replace(str, "-", "_", -1) + str = strings.Replace(str, ".", "_", -1) return str } diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 1c93609..952ea90 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -5,6 +5,7 @@ import ( "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" + "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" "github.com/google/uuid" "time" @@ -36,6 +37,23 @@ func TestAllTypes(t *testing.T) { testutils.AssertJSON(t, dest, allTypesJson) } +func TestAllTypesViewSelect(t *testing.T) { + + type AllTypesView model.AllTypes + + dest := []AllTypesView{} + + err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + + if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert + return + } + + testutils.AssertJSON(t, dest, allTypesJson) +} + func TestUUID(t *testing.T) { query := AllTypes. diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index a191214..b519ea9 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -1,8 +1,8 @@ package mysql import ( - "bytes" "github.com/go-jet/jet/generator/mysql" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" "gotest.tools/assert" "io/ioutil" @@ -63,53 +63,40 @@ func assertGeneratedFiles(t *testing.T) { tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") assert.NilError(t, err) - assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", - "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go", "rental.go", "staff.go", "store.go") - assertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") assert.NilError(t, err) - assertFileNameEqual(t, enumFiles, "film_rating.go") - assertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) + testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") + testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") assert.NilError(t, err) - assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", - "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", - "payment.go", "rental.go", "staff.go", "store.go", "film_rating.go") + testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", + "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go", + "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") - assertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) -} - -func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { - enumFileData, err := ioutil.ReadFile(filePath) - - assert.NilError(t, err) - - beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - - //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - - assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) -} - -func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { - - fileNamesMap := map[string]bool{} - - for _, fileInfo := range fileInfos { - fileNamesMap[fileInfo.Name()] = true - } - - for _, fileName := range fileNames { - assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") - } + testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) } var mpaaRatingEnumFile = ` @@ -200,3 +187,57 @@ type Actor struct { LastUpdate time.Time } ` + +var actorInfoSQLBuilerFile = ` +package view + +import ( + "github.com/go-jet/jet/mysql" +) + +var ActorInfo = newActorInfoTable() + +type ActorInfoTable struct { + mysql.Table + + //Columns + ActorID mysql.ColumnInteger + FirstName mysql.ColumnString + LastName mysql.ColumnString + FilmInfo mysql.ColumnString + + AllColumns mysql.IColumnList + MutableColumns mysql.IColumnList +} + +// creates new ActorInfoTable with assigned alias +func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { + aliasTable := newActorInfoTable() + + aliasTable.Table.AS(alias) + + return aliasTable +} + +func newActorInfoTable() *ActorInfoTable { + var ( + ActorIDColumn = mysql.IntegerColumn("actor_id") + FirstNameColumn = mysql.StringColumn("first_name") + LastNameColumn = mysql.StringColumn("last_name") + FilmInfoColumn = mysql.StringColumn("film_info") + ) + + return &ActorInfoTable{ + Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + + //Columns + ActorID: ActorIDColumn, + FirstName: FirstNameColumn, + LastName: LastNameColumn, + FilmInfo: FilmInfoColumn, + + AllColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + MutableColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + } +} +` diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 602a00d..0e20ae2 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" + "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view" "gotest.tools/assert" "testing" @@ -16,7 +17,7 @@ func TestSelect_ScanToStruct(t *testing.T) { query := Actor. SELECT(Actor.AllColumns). DISTINCT(). - WHERE(Actor.ActorID.EQ(Int(1))) + WHERE(Actor.ActorID.EQ(Int(2))) testutils.AssertStatementSql(t, query, ` SELECT DISTINCT actor.actor_id AS "actor.actor_id", @@ -25,20 +26,20 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id", actor.last_update AS "actor.last_update" FROM dvds.actor WHERE actor.actor_id = ?; -`, int64(1)) +`, int64(2)) actor := model.Actor{} err := query.Query(db, &actor) assert.NilError(t, err) - assert.DeepEqual(t, actor, actor1) + assert.DeepEqual(t, actor, actor2) } -var actor1 = model.Actor{ - ActorID: 1, - FirstName: "PENELOPE", - LastName: "GUINESS", +var actor2 = model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), } @@ -62,7 +63,7 @@ ORDER BY actor.actor_id; assert.NilError(t, err) assert.Equal(t, len(dest), 200) - assert.DeepEqual(t, dest[0], actor1) + assert.DeepEqual(t, dest[1], actor2) //testutils.PrintJson(dest) //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") @@ -640,3 +641,60 @@ ORDER BY payment.customer_id; assert.NilError(t, err) } + +func TestSimpleView(t *testing.T) { + query := SELECT( + view.ActorInfo.AllColumns, + ). + FROM(view.ActorInfo). + ORDER_BY(view.ActorInfo.ActorID). + LIMIT(10) + + type ActorInfo struct { + ActorID int + FirstName string + LastName string + FilmInfo string + } + + var dest []ActorInfo + + err := query.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 10) + testutils.AssertJSON(t, dest[1:2], ` +[ + { + "ActorID": 2, + "FirstName": "NICK", + "LastName": "WAHLBERG", + "FilmInfo": "Action: BULL SHAWSHANK; Animation: FIGHT JAWBREAKER; Children: JERSEY SASSY; Classics: DRACULA CRYSTAL, GILBERT PELICAN; Comedy: MALLRATS UNITED, RUSHMORE MERMAID; Documentary: ADAPTATION HOLES; Drama: WARDROBE PHANTOM; Family: APACHE DIVINE, CHISUM BEHAVIOR, INDIAN LOVE, MAGUIRE APACHE; Foreign: BABY HALL, HAPPINESS UNITED; Games: ROOF CHAMPION; Music: LUCKY FLYING; New: DESTINY SATURDAY, FLASH WARS, JEKYLL FROGMEN, MASK PEACH; Sci-Fi: CHAINSAW UPTOWN, GOODFELLAS SALUTE; Travel: LIAISONS SWEET, SMILE EARRING" + } +] +`) +} + +func TestJoinViewWithTable(t *testing.T) { + query := SELECT( + view.CustomerList.AllColumns, + Rental.AllColumns, + ). + FROM(view.CustomerList. + INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)), + ). + ORDER_BY(view.CustomerList.ID). + WHERE(view.CustomerList.ID.LT_EQ(Int(2))) + + var dest []struct { + model.CustomerList `sql:"primary_key=ID"` + Rentals []model.Rental + } + + err := query.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 2) + assert.Equal(t, len(dest[0].Rentals), 32) + assert.Equal(t, len(dest[1].Rentals), 27) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 2b91c7a..a613173 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -6,6 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" + "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" "github.com/google/uuid" "gotest.tools/assert" @@ -23,6 +24,19 @@ func TestAllTypesSelect(t *testing.T) { assert.DeepEqual(t, dest[1], allTypesRow1) } +func TestAllTypesViewSelect(t *testing.T) { + + type AllTypesView model.AllTypes + + dest := []AllTypesView{} + + err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) + assert.NilError(t, err) + + assert.DeepEqual(t, dest[0], AllTypesView(allTypesRow0)) + assert.DeepEqual(t, dest[1], AllTypesView(allTypesRow1)) +} + func TestAllTypesInsertModel(t *testing.T) { query := AllTypes.INSERT(AllTypes.AllColumns). MODEL(allTypesRow0). @@ -31,8 +45,8 @@ func TestAllTypesInsertModel(t *testing.T) { dest := []model.AllTypes{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.Equal(t, len(dest), 2) assert.DeepEqual(t, dest[0], allTypesRow0) assert.DeepEqual(t, dest[1], allTypesRow1) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index d73a850..8b04b6e 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -1,8 +1,8 @@ package postgres import ( - "bytes" "github.com/go-jet/jet/generator/postgres" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" "gotest.tools/assert" "io/ioutil" @@ -99,53 +99,39 @@ func assertGeneratedFiles(t *testing.T) { tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") assert.NilError(t, err) - assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "payment.go", "rental.go", "staff.go", "store.go") - assertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") assert.NilError(t, err) - assertFileNameEqual(t, enumFiles, "mpaa_rating.go") - assertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) + testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go") + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") assert.NilError(t, err) - assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", - "payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go") + "payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go", + "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") - assertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile) -} - -func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { - enumFileData, err := ioutil.ReadFile(filePath) - - assert.NilError(t, err) - - beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - - //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - - assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) -} - -func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { - - fileNamesMap := map[string]bool{} - - for _, fileInfo := range fileInfos { - fileNamesMap[fileInfo.Name()] = true - } - - for _, fileName := range fileNames { - assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") - } + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile) } var mpaaRatingEnumFile = ` @@ -236,3 +222,57 @@ type Actor struct { LastUpdate time.Time } ` + +var actorInfoSQLBuilderFile = ` +package view + +import ( + "github.com/go-jet/jet/postgres" +) + +var ActorInfo = newActorInfoTable() + +type ActorInfoTable struct { + postgres.Table + + //Columns + ActorID postgres.ColumnInteger + FirstName postgres.ColumnString + LastName postgres.ColumnString + FilmInfo postgres.ColumnString + + AllColumns postgres.IColumnList + MutableColumns postgres.IColumnList +} + +// creates new ActorInfoTable with assigned alias +func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { + aliasTable := newActorInfoTable() + + aliasTable.Table.AS(alias) + + return aliasTable +} + +func newActorInfoTable() *ActorInfoTable { + var ( + ActorIDColumn = postgres.IntegerColumn("actor_id") + FirstNameColumn = postgres.StringColumn("first_name") + LastNameColumn = postgres.StringColumn("last_name") + FilmInfoColumn = postgres.StringColumn("film_info") + ) + + return &ActorInfoTable{ + Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + + //Columns + ActorID: ActorIDColumn, + FirstName: FirstNameColumn, + LastName: LastNameColumn, + FilmInfo: FilmInfoColumn, + + AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + MutableColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + } +} +` diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 1fa2421..15ec4fd 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" + "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view" "gotest.tools/assert" "testing" "time" @@ -19,15 +20,15 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor -WHERE actor.actor_id = 1; +WHERE actor.actor_id = 2; ` query := Actor. SELECT(Actor.AllColumns). DISTINCT(). - WHERE(Actor.ActorID.EQ(Int(1))) + WHERE(Actor.ActorID.EQ(Int(2))) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2)) actor := model.Actor{} err := query.Query(db, &actor) @@ -35,9 +36,9 @@ WHERE actor.actor_id = 1; assert.NilError(t, err) expectedActor := model.Actor{ - ActorID: 1, - FirstName: "Penelope", - LastName: "Guiness", + ActorID: 2, + FirstName: "Nick", + LastName: "Wahlberg", LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), } @@ -1722,3 +1723,62 @@ ORDER BY payment.customer_id; assert.NilError(t, err) } + +func TestSimpleView(t *testing.T) { + query := SELECT( + view.ActorInfo.AllColumns, + ). + FROM(view.ActorInfo). + ORDER_BY(view.ActorInfo.ActorID). + LIMIT(10) + + type ActorInfo struct { + ActorID int + FirstName string + LastName string + FilmInfo string + } + + var dest []ActorInfo + + err := query.Query(db, &dest) + assert.NilError(t, err) + + testutils.AssertJSON(t, dest[1:2], ` +[ + { + "ActorID": 2, + "FirstName": "Nick", + "LastName": "Wahlberg", + "FilmInfo": "Action: Bull Shawshank, Animation: Fight Jawbreaker, Children: Jersey Sassy, Classics: Dracula Crystal, Gilbert Pelican, Comedy: Mallrats United, Rushmore Mermaid, Documentary: Adaptation Holes, Drama: Wardrobe Phantom, Family: Apache Divine, Chisum Behavior, Indian Love, Maguire Apache, Foreign: Baby Hall, Happiness United, Games: Roof Champion, Music: Lucky Flying, New: Destiny Saturday, Flash Wars, Jekyll Frogmen, Mask Peach, Sci-Fi: Chainsaw Uptown, Goodfellas Salute, Travel: Liaisons Sweet, Smile Earring" + } +] +`) + +} + +func TestJoinViewWithTable(t *testing.T) { + query := SELECT( + view.CustomerList.AllColumns, + Rental.AllColumns, + ). + FROM(view.CustomerList. + INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)), + ). + ORDER_BY(view.CustomerList.ID). + WHERE(view.CustomerList.ID.LT_EQ(Int(2))) + + var dest []struct { + model.CustomerList `sql:"primary_key=ID"` + Rentals []model.Rental + } + + fmt.Println(query.DebugSql()) + + err := query.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 2) + assert.Equal(t, len(dest[0].Rentals), 32) + assert.Equal(t, len(dest[1].Rentals), 27) +} diff --git a/tests/testdata b/tests/testdata index 7f3f3cc..088a035 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 7f3f3cc26ce34324f3699d6b422376671b827490 +Subproject commit 088a035d179707c9c59972922f6915c992347d2e