From 71fb1c7cd1dead9c3fa0f9bf6814470569699c8e Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 16 Dec 2023 11:43:40 +0100 Subject: [PATCH] Add support for sqlite generated columns. --- generator/sqlite/query_set.go | 36 +++++++++++++- internal/utils/semantic/version.go | 66 +++++++++++++++++++++++++ tests/sqlite/sample_test.go | 78 ++++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 internal/utils/semantic/version.go create mode 100644 tests/sqlite/sample_test.go diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index d1d0bf7..9cf6541 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/semantic" "github.com/go-jet/jet/v2/qrm" "strings" ) @@ -42,16 +43,45 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy return tables, nil } +func getTableInfoQuery(db *sql.DB) (string, error) { + var version string + err := db.QueryRow("select sqlite_version();").Scan(&version) + + if err != nil { + return "", fmt.Errorf("failed to get sqlite version: %w", err) + } + + sqliteVersion, err := semantic.VersionFromString(version) + + if err != nil { + return "", fmt.Errorf("can't parse sqlite version: %w", err) + } + + // generated columns were added in version 3.26.0 + if sqliteVersion.Lt(semantic.Version{Major: 3, Minor: 26, Patch: 0}) { + return `select * from pragma_table_info(?);`, nil + } + + return `select * from pragma_table_xinfo(?);`, nil +} + func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { - query := fmt.Sprintf(`select * from pragma_table_info(?);`) + + tableInfoQuery, err := getTableInfoQuery(db) + + if err != nil { + return nil, err + } + var columnInfos []struct { Name string Type string NotNull int32 Pk int32 + Hidden int32 } - _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + _, err = qrm.Query(context.Background(), db, tableInfoQuery, []interface{}{tableName}, &columnInfos) if err != nil { return nil, fmt.Errorf("failed to query '%s' column metadata: %w", tableName, err) } @@ -60,11 +90,13 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t for _, columnInfo := range columnInfos { columnType := getColumnType(columnInfo.Type) + isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column columns = append(columns, metadata.Column{ Name: columnInfo.Name, IsPrimaryKey: columnInfo.Pk != 0, IsNullable: columnInfo.NotNull != 1, + IsGenerated: isGenerated, DataType: metadata.DataType{ Name: columnType, Kind: metadata.BaseType, diff --git a/internal/utils/semantic/version.go b/internal/utils/semantic/version.go new file mode 100644 index 0000000..b13413f --- /dev/null +++ b/internal/utils/semantic/version.go @@ -0,0 +1,66 @@ +package semantic + +import ( + "fmt" + "strconv" + "strings" +) + +// Version struct holds semantic versioning information +type Version struct { + Major int + Minor int + Patch int +} + +// VersionFromString creates new semantic Version by parsing version string +func VersionFromString(version string) (Version, error) { + parts := strings.Split(version, ".") + + var ret Version + + if len(parts) > 0 { + major, err := strconv.Atoi(parts[0]) + + if err != nil { + return ret, fmt.Errorf("major is not a number: %w", err) + } + + ret.Major = major + } + + if len(parts) > 1 { + minor, err := strconv.Atoi(parts[1]) + + if err != nil { + return ret, fmt.Errorf("minor is not a number: %w", err) + } + + ret.Minor = minor + } + + if len(parts) > 2 { + patch, err := strconv.Atoi(parts[2]) + + if err != nil { + return ret, fmt.Errorf("patch is not a number: %w", err) + } + + ret.Patch = patch + } + + return ret, nil +} + +// Lt returns true if this version is less than version parameter +func (v Version) Lt(version Version) bool { + if v.Major < version.Major { + return true + } + + if v.Minor < version.Minor { + return true + } + + return v.Patch < version.Patch +} diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go new file mode 100644 index 0000000..1671f5d --- /dev/null +++ b/tests/sqlite/sample_test.go @@ -0,0 +1,78 @@ +package sqlite + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "testing" + + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" +) + +func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { + + t.Run("should not have the generated column in mutableColumns", func(t *testing.T) { + require.Equal(t, 2, len(People.MutableColumns)) + require.Equal(t, People.PeopleName, People.MutableColumns[0]) + require.Equal(t, People.PeopleHeightCm, People.MutableColumns[1]) + }) + + t.Run("should query with all columns", func(t *testing.T) { + query := SELECT( + People.AllColumns, + ).FROM( + People, + ).WHERE( + People.PeopleID.EQ(Int(3)), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT people.people_id AS "people.people_id", + people.people_name AS "people.people_name", + people.people_height_cm AS "people.people_height_cm", + people.people_height_in AS "people.people_height_in" +FROM people +WHERE people.people_id = ?; +`) + var result model.People + + err := query.Query(sampleDB, &result) + require.NoError(t, err) + + require.Equal(t, "Carla", result.PeopleName) + require.Equal(t, 155., *result.PeopleHeightCm) + require.InEpsilon(t, 61.02, *result.PeopleHeightIn, 1e-3) + }) + + t.Run("should insert without generated columns", func(t *testing.T) { + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + insertQuery := People.INSERT( + People.MutableColumns, + ).MODEL( + model.People{ + PeopleName: "Dario", + PeopleHeightCm: testutils.Float64Ptr(120), + }, + ).RETURNING( + People.AllColumns, + ) + + testutils.AssertDebugStatementSql(t, insertQuery, ` +INSERT INTO people (people_name, people_height_cm) +VALUES ('Dario', 120) +RETURNING people.people_id AS "people.people_id", + people.people_name AS "people.people_name", + people.people_height_cm AS "people.people_height_cm", + people.people_height_in AS "people.people_height_in"; +`) + var result model.People + err := insertQuery.Query(tx, &result) + require.NoError(t, err) + + require.Equal(t, "Dario", result.PeopleName) + require.Equal(t, 120., *result.PeopleHeightCm) + }) + }) +}