Merge pull request #13 from go-jet/develop

Merge develop to master
This commit is contained in:
go-jet 2019-09-21 15:58:30 +02:00 committed by GitHub
commit fbf3b6d51c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
48 changed files with 1903 additions and 512 deletions

4
NOTICE
View file

@ -8,3 +8,7 @@ https://github.com/dropbox/godropbox/tree/master/database/sqlbuilder (BSD-3)
This product contains a modified portion of 'snaker' which can be obtained at:
https://github.com/serenize/snaker (MIT)
This product contains `FormatTimestamp` function from 'pq' which can be obtained at:
https://github.com/lib/pq (MIT)

View file

@ -12,7 +12,7 @@ convert database query result into desired arbitrary object structure.
Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases.
![jet](https://github.com/go-jet/jet/wiki/image/jet.png)
Jet is the easiest and fastest way to write complex SQL queries and map database query result
Jet is the easiest and the fastest way to write complex SQL queries and map database query result
into complex object composition. __It is not an ORM.__
## Motivation
@ -46,7 +46,7 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
* UPDATE `(SET, WHERE)`,
* DELETE `(WHERE, ORDER_BY, LIMIT)`,
* LOCK `(READ, WRITE)`
2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store
2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store
result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure.
@ -88,12 +88,13 @@ jet -source=PostgreSQL -host=localhost -port=5432 -user=jetuser -password=jetpas
```sh
Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable
Retrieving schema information...
FOUND 15 table(s), 1 enum(s)
Destination directory: ./gen/jetdb/dvds
Cleaning up schema destination directory...
FOUND 15 table(s), 7 view(s), 1 enum(s)
Cleaning up destination directory...
Generating table sql builder files...
Generating table model files...
Generating view sql builder files...
Generating enum sql builder files...
Generating table model files...
Generating view model files...
Generating enum model files...
Done
```
@ -102,9 +103,9 @@ be omitted (both databases doesn't have schema support).
_*User has to have a permission to read information schema tables._
As command output suggest, Jet will:
- connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema
- connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema
- delete everything in schema destination folder - `./gen/jetdb/dvds`,
- and finally generate SQL Builder and Model files for each schema table and enum.
- and finally generate SQL Builder and Model files for each schema table, view and enum.
Generated files folder structure will look like this:
@ -112,20 +113,24 @@ Generated files folder structure will look like this:
|-- gen # -path
| `-- jetdb # database name
| `-- dvds # schema name
| |-- enum # sql builder folder for enums
| |-- enum # sql builder package for enums
| | |-- mpaa_rating.go
| |-- table # sql builder folder for tables
| |-- table # sql builder package for tables
| |-- actor.go
| |-- address.go
| |-- category.go
| ...
| |-- model # model files for each table and enum
| |-- view # sql builder package for views
| |-- actor_info.go
| |-- film_list.go
| ...
| |-- model # data model types for each table, view and enum
| | |-- actor.go
| | |-- address.go
| | |-- mpaa_rating.go
| | ...
```
Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store
Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store
results of the SQL queries.
@ -167,7 +172,8 @@ stmt := SELECT(
Film.FilmID.ASC(),
)
```
Package(dot) import is used so that statement would resemble as much as possible as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
_Package(dot) import is used so that statement would resemble as much as possible as native SQL._
Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns
and can be compared only with integer columns and expressions.
@ -268,11 +274,12 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
#### Execute query and store result
Well formed SQL is just a first half the job. Lets see how can we make some sense of result set returned executing
Well formed SQL is just a first half of the job. Lets see how can we make some sense of result set returned executing
above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest.
First we have to create desired structure to store query result set.
This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information).
First we have to create desired structure to store query result.
This is done be combining autogenerated model types or it can be done
manually(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)) for more information).
Let's say this is our desired structure:
```go
@ -287,8 +294,8 @@ var dest []struct {
}
}
```
Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language
`Langauge` field is just a single model struct.
`Films` field is a slice because one actor can act in multiple films, and because each film belongs to one language
`Langauge` field is just a single model struct. `Film` can belong to multiple categories.
_*There is no limitation of how big or nested destination can be._
Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`.
@ -504,12 +511,14 @@ The biggest benefit is speed. Speed is improved in 3 major areas:
##### Speed of development
Writing SQL queries is much easier, because programmer has the help of SQL code completion and SQL type safety directly in Go.
Writing code is much faster and code is more robust. Automatic scan to arbitrary structure removes a lot of headache and
boilerplate code needed to structure database query result.
Writing SQL queries is faster and easier, because the developers have help of SQL code completion and SQL type safety directly from Go.
Automatic scan to arbitrary structure removes a lot of headache and boilerplate code needed to structure database query result.
##### Speed of execution
While ORM libraries can introduce significant performance penalties due to number of round-trips to the database,
Jet will always perform much better, because of the single database call.
Common web and database server usually are not on the same physical machine, and there is some latency between them.
Latency can vary from 5ms to 50+ms. In majority of cases query executed on database is simple query lasting no more than 1ms.
In those cases web server handler execution time is directly proportional to latency between server and database.
@ -521,14 +530,14 @@ With Jet, handler time lost on latency between server and database is constant.
return result in one database call. Handler execution will be only proportional to the number of rows returned from database.
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call.
With Jet you can even join the whole database and store the whole structured result in one database call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40).
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found
The most expensive bugs are the one on the production and the least expensive are those found during development.
With automatically generated type safe SQL not only queries are written faster but bugs are found sooner.
With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner.
Lets return to quick start example, and take closer look at a line:
```go
AND(Film.Length.GT(Int(180))),

View file

@ -91,9 +91,7 @@ func jsonSave(path string, v interface{}) {
err := ioutil.WriteFile(path, jsonText, 0644)
if err != nil {
panic(err)
}
panicOnError(err)
}
func printStatementInfo(stmt SelectStatement) {

View file

@ -771,7 +771,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType)
}
} else if isPrimaryKey(field) {
} else if isPrimaryKey(field, parentField) {
index := s.typeToColumnIndex(newTypeName, fieldName)
if index < 0 {
@ -813,9 +813,7 @@ func (s *scanContext) rowElem(index int) interface{} {
value, err := valuer.Value()
if err != nil {
panic(err)
}
utils.PanicOnError(err)
return value
}
@ -837,13 +835,45 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
return newElem
}
func isPrimaryKey(field reflect.StructField) bool {
func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool {
if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite {
return isPrimaryKey
}
sqlTag := field.Tag.Get("sql")
return sqlTag == "primary_key"
}
func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) {
if parentField == nil {
return
}
sqlTag := parentField.Tag.Get("sql")
if !strings.HasPrefix(sqlTag, "primary_key") {
return
}
parts := strings.Split(sqlTag, "=")
if len(parts) < 2 {
return
}
primaryKeyColumns := strings.Split(parts[1], ",")
for _, primaryKeyCol := range primaryKeyColumns {
if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) {
return true, true
}
}
return true, false
}
func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr {
return reflectType

View file

@ -62,7 +62,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = parseTime(v)
return
default:
return fmt.Errorf("can't scan time from %v", value)
return fmt.Errorf("can't scan time.Time from %v", value)
}
}

View file

@ -0,0 +1,147 @@
package internal
import (
"fmt"
"gotest.tools/assert"
"testing"
"time"
)
func TestNullByteArray(t *testing.T) {
var array NullByteArray
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan([]byte("bytea")))
assert.Equal(t, array.Valid, true)
assert.Equal(t, string(array.ByteArray), string([]byte("bytea")))
assert.Error(t, array.Scan(12), "can't scan []byte from 12")
}
func TestNullTime(t *testing.T) {
var array NullTime
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
time := time.Now()
assert.NilError(t, array.Scan(time))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, time)
assert.NilError(t, array.Scan([]byte("13:10:11")))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.NilError(t, array.Scan("13:10:11"))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.Error(t, array.Scan(12), "can't scan time.Time from 12")
}
func TestNullInt8(t *testing.T) {
var array NullInt8
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int8(11))
assert.Error(t, array.Scan("text"), "can't scan int8 from text")
}
func TestNullInt16(t *testing.T) {
var array NullInt16
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int16(11))
assert.NilError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(20))
assert.NilError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(30))
assert.NilError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(30))
assert.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int32(11))
assert.NilError(t, array.Scan(int32(32)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(32))
assert.NilError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(20))
assert.NilError(t, array.Scan(uint16(16)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(16))
assert.NilError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(30))
assert.NilError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(30))
assert.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(float64(64)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, float32(64))
assert.NilError(t, array.Scan(float32(32)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, float32(32))
assert.Error(t, array.Scan(12), "can't scan float32 from 12")
}

View file

@ -142,13 +142,10 @@ func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
return ""
}
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) {
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
defer rows.Close()
ret := []ColumnMetaData{}
@ -157,19 +154,13 @@ func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableN
var name, isNullable, dataType, enumName string
var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
}
err = rows.Err()
utils.PanicOnError(err)
if err != nil {
return nil, err
}
return ret, nil
return ret
}

View file

@ -11,5 +11,5 @@ type DialectQuerySet interface {
ListOfColumnsQuery() string
ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error)
GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData
}

View file

@ -3,41 +3,43 @@ package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
)
// SchemaMetaData struct
type SchemaMetaData struct {
TableInfos []MetaData
EnumInfos []MetaData
TablesMetaData []MetaData
ViewsMetaData []MetaData
EnumsMetaData []MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) {
// IsEmpty returns true if schema info does not contain any table, views or enums metadata
func (s SchemaMetaData) IsEmpty() bool {
return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0
}
schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName)
const (
baseTable = "BASE TABLE"
view = "VIEW"
)
if err != nil {
return
}
// GetSchemaMetaData returns schema information from db connection.
func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) {
schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName)
schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable)
schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view)
schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName)
if err != nil {
return
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),",
len(schemaInfo.EnumsMetaData), "enum(s)")
return
}
func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) {
func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName)
if err != nil {
return nil, err
}
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType)
utils.PanicOnError(err)
defer rows.Close()
ret := []MetaData{}
@ -45,24 +47,15 @@ func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]M
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName)
if err != nil {
return nil, err
}
tableInfo := GetTableMetaData(db, querySet, schemaName, tableName)
ret = append(ret, tableInfo)
}
err = rows.Err()
utils.PanicOnError(err)
if err != nil {
return nil, err
}
return ret, nil
return ret
}

View file

@ -67,46 +67,32 @@ func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table"
}
// GetTableInfo returns table info metadata
func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) {
// GetTableMetaData returns table info metadata
func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) {
tableInfo.SchemaName = schemaName
tableInfo.name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName)
if err != nil {
return
}
tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName)
if err != nil {
return
}
tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName)
tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName)
return
}
func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) {
func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool {
rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
primaryKeyMap := map[string]bool{}
for rows.Next() {
primaryKey := ""
err := rows.Scan(&primaryKey)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
primaryKeyMap[primaryKey] = true
}
return primaryKeyMap, nil
return primaryKeyMap
}

View file

@ -12,89 +12,65 @@ import (
)
// GenerateFiles generates Go files from tables and enums metadata
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
if len(tables) == 0 && len(enums) == 0 {
return nil
func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) {
if schemaInfo.IsEmpty() {
return
}
fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir)
utils.PanicOnError(err)
if err != nil {
return err
}
generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect)
generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect)
generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect)
fmt.Println("Generating table sql builder files...")
err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(destDir, "model", tableModelTemplate, tables, dialect)
if err != nil {
return err
}
if len(enums) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(destDir, "model", enumModelTemplate, enums, dialect)
if err != nil {
return err
}
}
generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect)
generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect)
generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect)
fmt.Println("Done")
return nil
}
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s sql builder files...\n", fileTypes)
generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect)
}
func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s model files...\n", fileTypes)
generateGoFiles(destDir, "model", modelTemplate, metaData, dialect)
}
func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) {
modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
utils.PanicOnError(err)
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
if err != nil {
return err
}
utils.PanicOnError(err)
for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect)
if err != nil {
return err
}
text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName})
utils.PanicOnError(err)
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
utils.PanicOnError(err)
}
return nil
return
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) {
func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
@ -102,7 +78,13 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect1 je
return time.Now().Format(time.RFC850)
},
"dialect": func() jet.Dialect {
return dialect1
return dialect
},
"param": func(name string) interface{} {
if len(params) > 0 {
return params[0][name]
}
return ""
},
}).Parse(templateText)

View file

@ -18,7 +18,7 @@ var tableSQLBuilderTemplate = `
{{- end}}
{{- end}}
package table
package {{param "package"}}
import (
"github.com/go-jet/jet/{{dialect.PackageName}}"

View file

@ -22,50 +22,34 @@ type DBConnection struct {
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
if err != nil {
return err
}
func Generate(destDir string, dbConn DBConnection) (err error) {
defer utils.ErrorCatch(&err)
db := openConnection(dbConn)
defer utils.DBClose(db)
fmt.Println("Retrieving database information...")
// No schemas in MySQL
dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &mySqlQuerySet{})
if err != nil {
return err
}
dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{})
genPath := path.Join(destDir, dbConn.DBName)
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect)
if err != nil {
return err
}
template.GenerateFiles(genPath, dbInfo, mysql.Dialect)
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
func openConnection(dbConn DBConnection) *sql.DB {
var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" {
connectionString += "?" + dbConn.Params
}
db, err := sql.Open("mysql", connectionString)
fmt.Println("Connecting to MySQL database: " + connectionString)
if err != nil {
return nil, err
}
db, err := sql.Open("mysql", connectionString)
utils.PanicOnError(err)
err = db.Ping()
utils.PanicOnError(err)
if err != nil {
return nil, err
}
return db, nil
return db
}

View file

@ -3,6 +3,7 @@ package mysql
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/utils"
"strings"
)
@ -13,7 +14,7 @@ func (m *mySqlQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = 'BASE TABLE';
WHERE table_schema = ? and table_type = ?;
`
}
@ -46,17 +47,14 @@ func (m *mySqlQuerySet) ListOfEnumsQuery() string {
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5)
FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE';
WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
`
}
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData {
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
defer rows.Close()
ret := []metadata.MetaData{}
@ -65,9 +63,7 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad
var enumName string
var enumValues string
err = rows.Scan(&enumName, &enumValues)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
@ -78,11 +74,8 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad
}
err = rows.Err()
utils.PanicOnError(err)
if err != nil {
return nil, err
}
return ret, nil
return ret
}

View file

@ -25,31 +25,20 @@ type DBConnection struct {
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
func Generate(destDir string, dbConn DBConnection) (err error) {
defer utils.ErrorCatch(&err)
db, err := openConnection(dbConn)
utils.PanicOnError(err)
defer utils.DBClose(db)
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &postgresQuerySet{})
if err != nil {
return err
}
schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{})
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
template.GenerateFiles(genPath, schemaInfo, postgres.Dialect)
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect)
if err != nil {
return err
}
return nil
return
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {

View file

@ -3,6 +3,7 @@ package postgres
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/utils"
)
// postgresQuerySet is dialect query set for PostgreSQL
@ -12,7 +13,7 @@ func (p *postgresQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';
where table_schema = $1 and table_type = $2;
`
}
@ -45,12 +46,9 @@ WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
}
func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData {
rows, err := db.Query(p.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
defer rows.Close()
enumsInfosMap := map[string][]string{}
@ -58,9 +56,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
if err != nil {
return nil, err
}
utils.PanicOnError(err)
enumValues := enumsInfosMap[enumName]
@ -70,10 +66,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me
}
err = rows.Err()
if err != nil {
return nil, err
}
utils.PanicOnError(err)
ret := []metadata.MetaData{}
@ -84,5 +77,5 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me
})
}
return ret, nil
return ret
}

View file

@ -0,0 +1,42 @@
package pq
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
import (
"strconv"
"time"
)
// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq
func FormatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}

View file

@ -0,0 +1,39 @@
package pq
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
import (
"testing"
"time"
)
var formatTimeTests = []struct {
time time.Time
expected string
}{
{time.Time{}, "0001-01-01 00:00:00Z"},
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"},
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"},
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"},
{time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"},
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"},
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"},
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"},
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"},
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"},
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"},
{time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"},
{time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"},
}
func TestFormatTs(t *testing.T) {
for i, tt := range formatTimeTests {
val := string(FormatTimestamp(tt.time))
if val != tt.expected {
t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected)
}
}
}

View file

@ -134,7 +134,8 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
// ClauseOrderBy struct
type ClauseOrderBy struct {
List []OrderByClause
List []OrderByClause
SkipNewLine bool
}
// Serialize serializes clause into SQLBuilder
@ -143,7 +144,9 @@ func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder)
return
}
out.NewLine()
if !o.SkipNewLine {
out.NewLine()
}
out.WriteString("ORDER BY")
out.IncreaseIdent()
@ -469,3 +472,37 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString(string(i.LockMode))
out.WriteString("MODE")
}
// WindowDefinition struct
type WindowDefinition struct {
Name string
Window Window
}
// ClauseWindow struct
type ClauseWindow struct {
Definitions []WindowDefinition
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) {
if len(i.Definitions) == 0 {
return
}
out.NewLine()
out.WriteString("WINDOW")
for i, def := range i.Definitions {
if i > 0 {
out.WriteString(", ")
}
out.WriteString(def.Name)
out.WriteString("AS")
if def.Window == nil {
out.WriteString("()")
continue
}
def.Window.serialize(statementType, out)
}
}

View file

@ -81,68 +81,154 @@ func LOG(floatExpression FloatExpression) FloatExpression {
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression {
return NewFloatFunc("AVG", numericExpression)
func AVG(numericExpression NumericExpression) floatWindowExpression {
return NewFloatWindowFunc("AVG", numericExpression)
}
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
func BIT_AND(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_AND", integerExpression)
func BIT_AND(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("BIT_AND", integerExpression)
}
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
func BIT_OR(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_OR", integerExpression)
func BIT_OR(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("BIT_OR", integerExpression)
}
// BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false
func BOOL_AND(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_AND", boolExpression)
func BOOL_AND(boolExpression BoolExpression) boolWindowExpression {
return newBoolWindowFunc("BOOL_AND", boolExpression)
}
// BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false
func BOOL_OR(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_OR", boolExpression)
func BOOL_OR(boolExpression BoolExpression) boolWindowExpression {
return newBoolWindowFunc("BOOL_OR", boolExpression)
}
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
func COUNT(expression Expression) IntegerExpression {
return newIntegerFunc("COUNT", expression)
func COUNT(expression Expression) integerWindowExpression {
return newIntegerWindowFunc("COUNT", expression)
}
// EVERY is aggregate function. Returns true if all input values are true, otherwise false
func EVERY(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("EVERY", boolExpression)
func EVERY(boolExpression BoolExpression) boolWindowExpression {
return newBoolWindowFunc("EVERY", boolExpression)
}
// MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression {
return NewFloatFunc("MAX", floatExpression)
func MAXf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("MAX", floatExpression)
}
// MAXi is aggregate function. Returns maximum value of int expression across all input values
func MAXi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MAX", integerExpression)
func MAXi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MAX", integerExpression)
}
// MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression {
return NewFloatFunc("MIN", floatExpression)
func MINf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("MIN", floatExpression)
}
// MINi is aggregate function. Returns minimum value of int expression across all input values
func MINi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MIN", integerExpression)
func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MIN", integerExpression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression {
return NewFloatFunc("SUM", floatExpression)
func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("SUM", floatExpression)
}
// SUMi is aggregate function. Returns sum of expression across all integer expression.
func SUMi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("SUM", integerExpression)
func SUMi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("SUM", integerExpression)
}
// ----------------- Window functions -------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
func ROW_NUMBER() integerWindowExpression {
return newIntegerWindowFunc("ROW_NUMBER")
}
// RANK of the current row with gaps; same as row_number of its first peer
func RANK() integerWindowExpression {
return newIntegerWindowFunc("RANK")
}
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
func DENSE_RANK() integerWindowExpression {
return newIntegerWindowFunc("DENSE_RANK")
}
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
func PERCENT_RANK() floatWindowExpression {
return NewFloatWindowFunc("PERCENT_RANK")
}
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
func CUME_DIST() floatWindowExpression {
return NewFloatWindowFunc("CUME_DIST")
}
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
func NTILE(numOfBuckets int64) integerWindowExpression {
return newIntegerWindowFunc("NTILE", FixedLiteral(numOfBuckets))
}
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
func LAG(expr Expression, offsetAndDefault ...interface{}) windowExpression {
return leadLagImpl("LAG", expr, offsetAndDefault...)
}
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
func LEAD(expr Expression, offsetAndDefault ...interface{}) windowExpression {
return leadLagImpl("LEAD", expr, offsetAndDefault...)
}
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
func FIRST_VALUE(value Expression) windowExpression {
return newWindowFunc("FIRST_VALUE", value)
}
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
func LAST_VALUE(value Expression) windowExpression {
return newWindowFunc("LAST_VALUE", value)
}
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
func NTH_VALUE(value Expression, nth int64) windowExpression {
return newWindowFunc("NTH_VALUE", value, FixedLiteral(nth))
}
func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) windowExpression {
params := []Expression{expr}
if len(offsetAndDefault) >= 2 {
offset, ok := offsetAndDefault[0].(int)
if !ok {
panic("jet: LAG offset should be an integer")
}
var defaultValue Expression
defaultValue, ok = offsetAndDefault[1].(Expression)
if !ok {
defaultValue = literal(offsetAndDefault[1])
}
params = append(params, FixedLiteral(offset), defaultValue)
}
return newWindowFunc(name, params...)
}
//------------ String functions ------------------//
@ -349,7 +435,7 @@ func TO_HEX(number IntegerExpression) StringExpression {
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression {
if len(matchType) > 0 {
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0]))
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, FixedLiteral(matchType[0]))
}
return newBoolFunc("REGEXP_LIKE", stringExp, pattern)
@ -391,7 +477,7 @@ func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0]))
timezFunc = newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
@ -406,7 +492,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0]))
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
@ -421,7 +507,7 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0]))
timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
@ -436,7 +522,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0]))
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0]))
} else {
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP")
}
@ -504,6 +590,16 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp
}
// NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := newFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun)
newFun.expressionInterfaceImpl.Parent = windowExpr
return windowExpr
}
func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(f.expressions...)
@ -536,10 +632,23 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression {
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
boolFunc.boolInterfaceImpl.parent = boolFunc
boolFunc.expressionInterfaceImpl.Parent = boolFunc
return boolFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression {
boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
intWindowFunc := newBoolWindowExpression(boolFunc)
boolFunc.boolInterfaceImpl.parent = intWindowFunc
boolFunc.expressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc
}
type floatFunc struct {
funcExpressionImpl
floatInterfaceImpl
@ -555,6 +664,18 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
return floatFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression {
floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
floatWindowFunc := newFloatWindowExpression(floatFunc)
floatFunc.floatInterfaceImpl.parent = floatWindowFunc
floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc
return floatWindowFunc
}
type integerFunc struct {
funcExpressionImpl
integerInterfaceImpl
@ -569,6 +690,18 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
return floatFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression {
integerFunc := &integerFunc{}
integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc)
intWindowFunc := newIntegerWindowExpression(integerFunc)
integerFunc.integerInterfaceImpl.parent = intWindowFunc
integerFunc.expressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc
}
type stringFunc struct {
funcExpressionImpl
stringInterfaceImpl

View file

@ -32,8 +32,8 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl
return &exp
}
// ConstLiteral is injected directly to SQL query, and does not appear in argument list.
func ConstLiteral(value interface{}) *literalExpressionImpl {
// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list.
func FixedLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value)
exp.constant = true

View file

@ -11,16 +11,9 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(true), "TRUE")
assert.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int8(-8)), "-8")
assert.Equal(t, argToString(int16(-16)), "-16")
assert.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(int32(-32)), "-32")
assert.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint8(8)), "8")
assert.Equal(t, argToString(uint16(16)), "16")
assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'")
@ -31,5 +24,12 @@ func TestArgToString(t *testing.T) {
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NilError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
assert.Equal(t, argToString(map[string]bool{}), "[Unsupported type]")
func() {
defer func() {
assert.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
}()
argToString(map[string]bool{})
}()
}

View file

@ -2,8 +2,11 @@ package jet
import (
"bytes"
"fmt"
"github.com/go-jet/jet/internal/3rdparty/pq"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"reflect"
"strconv"
"strings"
"time"
@ -139,28 +142,13 @@ func argToString(value interface{}) string {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(bindVal, 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
@ -173,9 +161,9 @@ func argToString(value interface{}) string {
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
return stringQuote(string(pq.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))
}
}

View file

@ -19,15 +19,17 @@ type Table interface {
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable {
func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable {
columnList := append([]ColumnExpression{column}, columns...)
t := tableImpl{
schemaName: schemaName,
name: name,
columnList: columns,
columnList: columnList,
}
for _, c := range columns {
for _, c := range columnList {
c.setTableName(name)
}

View file

@ -0,0 +1,33 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestNewTable(t *testing.T) {
newTable := NewTable("schema", "table", IntegerColumn("intCol"))
assert.Equal(t, newTable.SchemaName(), "schema")
assert.Equal(t, newTable.TableName(), "table")
assert.Equal(t, len(newTable.columns()), 1)
assert.Equal(t, newTable.columns()[0].Name(), "intCol")
}
func TestNewJoinTable(t *testing.T) {
newTable1 := NewTable("schema", "table", IntegerColumn("intCol1"))
newTable2 := NewTable("schema", "table2", IntegerColumn("intCol2"))
joinTable := NewJoinTable(newTable1, newTable2, InnerJoin, IntegerColumn("intCol1").EQ(IntegerColumn("intCol2")))
assertClauseSerialize(t, joinTable, `schema.table
INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`)
assert.Equal(t, joinTable.SchemaName(), "schema")
assert.Equal(t, joinTable.TableName(), "")
assert.Equal(t, len(joinTable.columns()), 2)
assert.Equal(t, joinTable.columns()[0].Name(), "intCol1")
assert.Equal(t, joinTable.columns()[1].Name(), "intCol2")
}

View file

@ -0,0 +1,146 @@
package jet
type commonWindowImpl struct {
expression Expression
window Window
}
func (w *commonWindowImpl) over(window ...Window) {
if len(window) > 0 {
w.window = window[0]
} else {
w.window = newWindowImpl(nil)
}
}
func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
w.expression.serialize(statement, out)
if w.window != nil {
out.WriteString("OVER")
w.window.serialize(statement, out)
}
}
// --------------------------------------
type windowExpression interface {
Expression
OVER(window ...Window) Expression
}
func newWindowExpression(Exp Expression) windowExpression {
newExp := &windowExpressionImpl{
Expression: Exp,
}
newExp.commonWindowImpl.expression = Exp
return newExp
}
type windowExpressionImpl struct {
Expression
commonWindowImpl
}
func (f *windowExpressionImpl) OVER(window ...Window) Expression {
f.commonWindowImpl.over(window...)
return f
}
func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// -----------------------------------------------------
type floatWindowExpression interface {
FloatExpression
OVER(window ...Window) FloatExpression
}
func newFloatWindowExpression(floatExp FloatExpression) floatWindowExpression {
newExp := &floatWindowExpressionImpl{
FloatExpression: floatExp,
}
newExp.commonWindowImpl.expression = floatExp
return newExp
}
type floatWindowExpressionImpl struct {
FloatExpression
commonWindowImpl
}
func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// ------------------------------------------------
type integerWindowExpression interface {
IntegerExpression
OVER(window ...Window) IntegerExpression
}
func newIntegerWindowExpression(intExp IntegerExpression) integerWindowExpression {
newExp := &integerWindowExpressionImpl{
IntegerExpression: intExp,
}
newExp.commonWindowImpl.expression = intExp
return newExp
}
type integerWindowExpressionImpl struct {
IntegerExpression
commonWindowImpl
}
func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// ------------------------------------------------
type boolWindowExpression interface {
BoolExpression
OVER(window ...Window) BoolExpression
}
func newBoolWindowExpression(boolExp BoolExpression) boolWindowExpression {
newExp := &boolWindowExpressionImpl{
BoolExpression: boolExp,
}
newExp.commonWindowImpl.expression = boolExp
return newExp
}
type boolWindowExpressionImpl struct {
BoolExpression
commonWindowImpl
}
func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}

186
internal/jet/window_func.go Normal file
View file

@ -0,0 +1,186 @@
package jet
// Window interface
type Window interface {
Serializer
ORDER_BY(expr ...OrderByClause) Window
ROWS(start FrameExtent, end ...FrameExtent) Window
RANGE(start FrameExtent, end ...FrameExtent) Window
GROUPS(start FrameExtent, end ...FrameExtent) Window
}
type windowImpl struct {
partitionBy []Expression
orderBy ClauseOrderBy
frameUnits string
start, end FrameExtent
parent Window
}
func newWindowImpl(parent Window) *windowImpl {
newWindow := &windowImpl{}
if parent == nil {
newWindow.parent = newWindow
} else {
newWindow.parent = parent
}
return newWindow
}
func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
out.WriteByte('(')
}
if w.partitionBy != nil {
out.WriteString("PARTITION BY")
serializeExpressionList(statement, w.partitionBy, ", ", out)
}
w.orderBy.SkipNewLine = true
w.orderBy.Serialize(statement, out)
if w.frameUnits != "" {
out.WriteString(w.frameUnits)
if w.end == nil {
w.start.serialize(statement, out)
} else {
out.WriteString("BETWEEN")
w.start.serialize(statement, out)
out.WriteString("AND")
w.end.serialize(statement, out)
}
}
if !contains(options, noWrap) {
out.WriteByte(')')
}
}
func (w *windowImpl) ORDER_BY(exprs ...OrderByClause) Window {
w.orderBy.List = exprs
return w.parent
}
func (w *windowImpl) ROWS(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "ROWS"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) RANGE(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "RANGE"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) GROUPS(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "GROUPS"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) setFrameRange(start FrameExtent, end ...FrameExtent) {
w.start = start
if len(end) > 0 {
w.end = end[0]
}
}
// PARTITION_BY window function constructor
func PARTITION_BY(exp Expression, exprs ...Expression) Window {
funImpl := newWindowImpl(nil)
funImpl.partitionBy = append([]Expression{exp}, exprs...)
return funImpl
}
// ORDER_BY window function constructor
func ORDER_BY(expr ...OrderByClause) Window {
funImpl := newWindowImpl(nil)
funImpl.orderBy.List = expr
return funImpl
}
// -----------------------------------------------
// FrameExtent interface
type FrameExtent interface {
Serializer
isFrameExtent()
}
// PRECEDING window frame clause
func PRECEDING(offset Serializer) FrameExtent {
return &frameExtentImpl{
preceding: true,
offset: offset,
}
}
// FOLLOWING window frame clause
func FOLLOWING(offset Serializer) FrameExtent {
return &frameExtentImpl{
preceding: false,
offset: offset,
}
}
type frameExtentImpl struct {
preceding bool
offset Serializer
}
func (f *frameExtentImpl) isFrameExtent() {}
func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if f == nil {
return
}
f.offset.serialize(statement, out)
if f.preceding {
out.WriteString("PRECEDING")
} else {
out.WriteString("FOLLOWING")
}
}
// -----------------------------------------------
// Window function keywords
var (
UNBOUNDED = keywordClause("UNBOUNDED")
CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"}
)
type frameExtentKeyword struct {
keywordClause
}
func (f frameExtentKeyword) isFrameExtent() {}
// -----------------------------------------------
// WindowName is used to specify window reference from WINDOW clause
func WindowName(name string) Window {
newWindow := &windowName{name: name}
newWindow.parent = newWindow
return newWindow
}
type windowName struct {
windowImpl
name string
}
func (w windowName) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.WriteString(w.name)
w.windowImpl.serialize(statement, out, noWrap)
out.WriteByte(')')
}

View file

@ -0,0 +1,21 @@
package jet
import "testing"
func TestFrameExtent(t *testing.T) {
assertClauseSerialize(t, PRECEDING(Int(2)), "$1 PRECEDING", int64(2))
assertClauseSerialize(t, FOLLOWING(Int(4)), "$1 FOLLOWING", int64(4))
}
func TestWindowFunctions(t *testing.T) {
assertClauseSerialize(t, PARTITION_BY(table1Col1), "(PARTITION BY table1.col1)")
assertClauseSerialize(t, PARTITION_BY(table1Col3).ORDER_BY(table1Col1), "(PARTITION BY table1.col3 ORDER BY table1.col1)")
assertClauseSerialize(t, ORDER_BY(table1Col1), "(ORDER BY table1.col1)")
assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1))), "(ORDER BY table1.col1 ROWS $1 PRECEDING)", int64(1))
assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1)), FOLLOWING(Int(33))),
"(ORDER BY table1.col1 ROWS BETWEEN $1 PRECEDING AND $2 FOLLOWING)", int64(1), int64(33))
assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
"(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)")
assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), CURRENT_ROW),
"(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)")
}

View file

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"gotest.tools/assert"
"io/ioutil"
"os"
@ -60,9 +61,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644)
if err != nil {
panic(err)
}
utils.PanicOnError(err)
}
// AssertJSONFile check if data json representation is the same as json at testRelativePath
@ -159,3 +158,31 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest
stmt.Query(db, dest)
}
// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
assert.Equal(t, len(fileInfos), len(fileNames))
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
}

View file

@ -1,6 +1,7 @@
package testutils
import (
"github.com/go-jet/jet/internal/utils"
"strings"
"time"
)
@ -9,9 +10,7 @@ import (
func Date(t string) *time.Time {
newTime, err := time.Parse("2006-01-02", t)
if err != nil {
panic(err)
}
utils.PanicOnError(err)
return &newTime
}
@ -27,9 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000")
if err != nil {
panic(err)
}
utils.PanicOnError(err)
return &newTime
}
@ -38,9 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
func TimeWithoutTimeZone(t string) *time.Time {
newTime, err := time.Parse("15:04:05", t)
if err != nil {
panic(err)
}
utils.PanicOnError(err)
return &newTime
}
@ -49,9 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time {
func TimeWithTimeZone(t string) *time.Time {
newTimez, err := time.Parse("15:04:05 -0700", t)
if err != nil {
panic(err)
}
utils.PanicOnError(err)
return &newTimez
}
@ -67,9 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
if err != nil {
panic(err)
}
utils.PanicOnError(err)
return &newTime
}

View file

@ -2,14 +2,13 @@ package utils
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/3rdparty/snaker"
"go/format"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"time"
)
// ToGoIdentifier converts database to Go identifier.
@ -104,44 +103,11 @@ func DirExists(path string) (bool, error) {
func replaceInvalidChars(str string) string {
str = strings.Replace(str, " ", "_", -1)
str = strings.Replace(str, "-", "_", -1)
str = strings.Replace(str, ".", "_", -1)
return str
}
// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq
func FormatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}
// IsNil check if v is nil
func IsNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
@ -174,3 +140,27 @@ func MustBeInitializedPtr(val interface{}, errorStr string) {
panic(errorStr)
}
}
// PanicOnError panics if err is not nil
func PanicOnError(err error) {
if err != nil {
panic(err)
}
}
// ErrorCatch is used in defer to recover from panics and to set err
func ErrorCatch(err *error) {
recovered := recover()
if recovered == nil {
return
}
recoveredErr, isError := recovered.(error)
if isError {
*err = recoveredErr
} else {
*err = fmt.Errorf("%v", recovered)
}
}

View file

@ -1,6 +1,7 @@
package utils
import (
"fmt"
"gotest.tools/assert"
"testing"
)
@ -23,3 +24,27 @@ func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier("My Table"), "MyTable")
assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
}
func TestErrorCatchErr(t *testing.T) {
var err error
func() {
defer ErrorCatch(&err)
panic(fmt.Errorf("newError"))
}()
assert.Error(t, err, "newError")
}
func TestErrorCatchNonErr(t *testing.T) {
var err error
func() {
defer ErrorCatch(&err)
panic(11)
}()
assert.Error(t, err, "11")
}

View file

@ -10,13 +10,13 @@ var Dialect = newDialect()
func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysql_REGEXP_LIKE_operator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysql_NOT_REGEXP_LIKE_operator
operatorSerializeOverrides["IS DISTINCT FROM"] = mysql_IS_DISTINCT_FROM
operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysql_IS_NOT_DISTINCT_FROM
operatorSerializeOverrides["/"] = mysql_DIVISION
operatorSerializeOverrides["#"] = mysql_BIT_XOR
operatorSerializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator
operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM
operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM
operatorSerializeOverrides["/"] = mysqlDivision
operatorSerializeOverrides["#"] = mysqlBitXor
operatorSerializeOverrides[jet.StringConcatOperator] = mysqlCONCAToperator
mySQLDialectParams := jet.DialectParams{
Name: "MySQL",
@ -32,7 +32,7 @@ func newDialect() jet.Dialect {
return jet.NewDialect(mySQLDialectParams)
}
func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator XOR")
@ -49,7 +49,7 @@ func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc {
}
}
func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator CONCAT")
@ -66,7 +66,7 @@ func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc {
}
}
func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator DIV")
@ -90,7 +90,7 @@ func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc {
}
}
func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
@ -102,15 +102,15 @@ func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc
}
}
func mysql_IS_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
out.WriteString("NOT(")
mysql_IS_NOT_DISTINCT_FROM(expressions...)(statement, out, options...)
mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...)
out.WriteString(")")
}
}
func mysql_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
@ -136,7 +136,7 @@ func mysql_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc
}
}
func mysql_NOT_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc {
func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")

View file

@ -85,6 +85,47 @@ var SUMi = jet.SUMi
// SUMf is aggregate function. Returns sum of float expression.
var SUMf = jet.SUMf
// -------------------- Window functions -----------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
var ROW_NUMBER = jet.ROW_NUMBER
// RANK of the current row with gaps; same as row_number of its first peer
var RANK = jet.RANK
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
var DENSE_RANK = jet.DENSE_RANK
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
var PERCENT_RANK = jet.PERCENT_RANK
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
var CUME_DIST = jet.CUME_DIST
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
var NTILE = jet.NTILE
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LAG = jet.LAG
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LEAD = jet.LEAD
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
var FIRST_VALUE = jet.FIRST_VALUE
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
var LAST_VALUE = jet.LAST_VALUE
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------//
// BIT_LENGTH returns number of bits in string expression
@ -181,7 +222,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampExpression {
// NOW returns current datetime
func NOW(fsp ...int) DateTimeExpression {
if len(fsp) > 0 {
return jet.NewTimestampFunc("NOW", jet.ConstLiteral(int64(fsp[0])))
return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0])))
}
return jet.NewTimestampFunc("NOW")
}

View file

@ -1,6 +1,8 @@
package mysql
import "github.com/go-jet/jet/internal/jet"
import (
"github.com/go-jet/jet/internal/jet"
)
// RowLock is interface for SELECT statement row lock types
type RowLock = jet.RowLock
@ -11,6 +13,27 @@ var (
SHARE = jet.NewRowLock("SHARE")
)
// Window function clauses
var (
PARTITION_BY = jet.PARTITION_BY
ORDER_BY = jet.ORDER_BY
UNBOUNDED = jet.UNBOUNDED
CURRENT_ROW = jet.CURRENT_ROW
)
// PRECEDING window frame clause
func PRECEDING(offset interface{}) jet.FrameExtent {
return jet.PRECEDING(toJetFrameOffset(offset))
}
// FOLLOWING window frame clause
func FOLLOWING(offset interface{}) jet.FrameExtent {
return jet.FOLLOWING(toJetFrameOffset(offset))
}
// Window is used to specify window reference from WINDOW clause
var Window = jet.WindowName
// SelectStatement is interface for MySQL SELECT statement
type SelectStatement interface {
Statement
@ -22,6 +45,7 @@ type SelectStatement interface {
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
@ -42,7 +66,7 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement {
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock)
newSelect.Select.Projections = toJetProjectionList(projections)
@ -66,6 +90,7 @@ type selectStatementImpl struct {
Where jet.ClauseWhere
GroupBy jet.ClauseGroupBy
Having jet.ClauseHaving
Window jet.ClauseWindow
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
Offset jet.ClauseOffset
@ -98,6 +123,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem
return s
}
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement {
s.OrderBy.List = orderByClauses
return s
@ -126,3 +156,31 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
}
//-----------------------------------------------------
type windowExpand struct {
selectStatement *selectStatementImpl
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
if len(window) == 0 {
return w.selectStatement
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
}
func toJetFrameOffset(offset interface{}) jet.Serializer {
if offset == UNBOUNDED {
return jet.UNBOUNDED
}
// check for interval expression
//if exp, ok := offset.(Expression); ok {
// return exp
//}
return jet.FixedLiteral(offset)
}

View file

@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectU
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, columns...),
SerializerTable: jet.NewTable(schemaName, name, column, columns...),
}
t.readableTableInterfaceImpl.parent = t

View file

@ -11,8 +11,8 @@ var Dialect = newDialect()
func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgres_REGEXP_LIKE_operator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgres_NOT_REGEXP_LIKE_operator
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgresREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgresNOTREGEXPLIKEoperator
operatorSerializeOverrides["CAST"] = postgresCAST
dialectParams := jet.DialectParams{
@ -54,7 +54,7 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc {
}
}
func postgres_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc {
func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
@ -80,7 +80,7 @@ func postgres_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeF
}
}
func postgres_NOT_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc {
func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")

View file

@ -87,6 +87,47 @@ var SUMf = jet.SUMf
// SUMi is aggregate function. Returns sum of expression across all integer expression.
var SUMi = jet.SUMi
// -------------------- Window functions -----------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
var ROW_NUMBER = jet.ROW_NUMBER
// RANK of the current row with gaps; same as row_number of its first peer
var RANK = jet.RANK
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
var DENSE_RANK = jet.DENSE_RANK
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
var PERCENT_RANK = jet.PERCENT_RANK
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
var CUME_DIST = jet.CUME_DIST
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
var NTILE = jet.NTILE
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LAG = jet.LAG
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LEAD = jet.LEAD
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
var FIRST_VALUE = jet.FIRST_VALUE
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
var LAST_VALUE = jet.LAST_VALUE
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------//
// BIT_LENGTH returns number of bits in string expression

View file

@ -1,6 +1,9 @@
package postgres
import "github.com/go-jet/jet/internal/jet"
import (
"github.com/go-jet/jet/internal/jet"
"math"
)
// RowLock is interface for SELECT statement row lock types
type RowLock = jet.RowLock
@ -13,6 +16,27 @@ var (
KEY_SHARE = jet.NewRowLock("KEY SHARE")
)
// Window function clauses
var (
PARTITION_BY = jet.PARTITION_BY
ORDER_BY = jet.ORDER_BY
UNBOUNDED = int64(math.MaxInt64)
CURRENT_ROW = jet.CURRENT_ROW
)
// PRECEDING window frame clause
func PRECEDING(offset int64) jet.FrameExtent {
return jet.PRECEDING(toJetFrameOffset(offset))
}
// FOLLOWING window frame clause
func FOLLOWING(offset int64) jet.FrameExtent {
return jet.FOLLOWING(toJetFrameOffset(offset))
}
// Window definition reference
var Window = jet.WindowName
// SelectStatement is interface for PostgreSQL SELECT statement
type SelectStatement interface {
Statement
@ -24,6 +48,7 @@ type SelectStatement interface {
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
@ -47,15 +72,9 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement {
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For)
// statementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
// &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy,
// &newSelect.Limit, &newSelect.Offset, &newSelect.For)
//
//newSelect.expressionStatementImpl.expressionInterfaceImpl.Parent = newSelect
newSelect.Select.Projections = toJetProjectionList(projections)
newSelect.From.Table = table
newSelect.Limit.Count = -1
@ -75,6 +94,7 @@ type selectStatementImpl struct {
Where jet.ClauseWhere
GroupBy jet.ClauseGroupBy
Having jet.ClauseHaving
Window jet.ClauseWindow
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
Offset jet.ClauseOffset
@ -106,6 +126,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem
return s
}
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement {
s.OrderBy.List = orderByClauses
return s
@ -129,3 +154,25 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement {
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
}
//-----------------------------------------------------
type windowExpand struct {
selectStatement *selectStatementImpl
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
if len(window) == 0 {
return w.selectStatement
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
}
func toJetFrameOffset(offset int64) jet.Serializer {
if offset == UNBOUNDED {
return jet.UNBOUNDED
}
return jet.FixedLiteral(offset)
}

View file

@ -109,10 +109,10 @@ type tableImpl struct {
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, columns...),
SerializerTable: jet.NewTable(schemaName, name, column, columns...),
}
t.readableTableInterfaceImpl.parent = t

View file

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/tests/dbconfig"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
@ -60,7 +61,7 @@ func initMySQLDB() {
cmd.Stdout = os.Stdout
err := cmd.Run()
panicOnError(err)
utils.PanicOnError(err)
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
Host: dbconfig.MySqLHost,
@ -70,7 +71,7 @@ func initMySQLDB() {
DBName: dbName,
})
panicOnError(err)
utils.PanicOnError(err)
}
}
@ -104,22 +105,16 @@ func initPostgresDB() {
SchemaName: schemaName,
SslMode: "disable",
})
panicOnError(err)
utils.PanicOnError(err)
}
}
func execFile(db *sql.DB, sqlFilePath string) {
testSampleSql, err := ioutil.ReadFile(sqlFilePath)
panicOnError(err)
utils.PanicOnError(err)
_, err = db.Exec(string(testSampleSql))
panicOnError(err)
}
func panicOnError(err error) {
if err != nil {
panic(err)
}
utils.PanicOnError(err)
}
func printOnError(err error) {

View file

@ -5,6 +5,7 @@ import (
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid"
"time"
@ -36,6 +37,23 @@ func TestAllTypes(t *testing.T) {
testutils.AssertJSON(t, dest, allTypesJson)
}
func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes
dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return
}
testutils.AssertJSON(t, dest, allTypesJson)
}
func TestUUID(t *testing.T) {
query := AllTypes.

View file

@ -1,8 +1,8 @@
package mysql
import (
"bytes"
"github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig"
"gotest.tools/assert"
"io/ioutil"
@ -15,22 +15,22 @@ const genTestDirRoot = "./.gentestdata3"
const genTestDir3 = "./.gentestdata3/mysql"
func TestGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir3)
assert.NilError(t, err)
err = mysql.Generate(genTestDir3, mysql.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
DBName: "dvds",
})
for i := 0; i < 3; i++ {
err := mysql.Generate(genTestDir3, mysql.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
DBName: "dvds",
})
assert.NilError(t, err)
assert.NilError(t, err)
assertGeneratedFiles(t)
assertGeneratedFiles(t)
}
err = os.RemoveAll(genTestDirRoot)
err := os.RemoveAll(genTestDirRoot)
assert.NilError(t, err)
}
@ -63,53 +63,40 @@ func assertGeneratedFiles(t *testing.T) {
tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table")
assert.NilError(t, err)
assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
assertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view")
assert.NilError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile)
// Enums SQL Builder files
enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum")
assert.NilError(t, err)
assertFileNameEqual(t, enumFiles, "film_rating.go")
assertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile)
testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files
modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model")
assert.NilError(t, err)
assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "film_rating.go")
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go",
"film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go",
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
assertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile)
}
func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile)
}
var mpaaRatingEnumFile = `
@ -200,3 +187,57 @@ type Actor struct {
LastUpdate time.Time
}
`
var actorInfoSQLBuilerFile = `
package view
import (
"github.com/go-jet/jet/mysql"
)
var ActorInfo = newActorInfoTable()
type ActorInfoTable struct {
mysql.Table
//Columns
ActorID mysql.ColumnInteger
FirstName mysql.ColumnString
LastName mysql.ColumnString
FilmInfo mysql.ColumnString
AllColumns mysql.IColumnList
MutableColumns mysql.IColumnList
}
// creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorInfoTable() *ActorInfoTable {
var (
ActorIDColumn = mysql.IntegerColumn("actor_id")
FirstNameColumn = mysql.StringColumn("first_name")
LastNameColumn = mysql.StringColumn("last_name")
FilmInfoColumn = mysql.StringColumn("film_info")
)
return &ActorInfoTable{
Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
FilmInfo: FilmInfoColumn,
AllColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
MutableColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
}
}
`

View file

@ -1,11 +1,13 @@
package mysql
import (
"fmt"
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view"
"gotest.tools/assert"
"testing"
@ -15,7 +17,7 @@ func TestSelect_ScanToStruct(t *testing.T) {
query := Actor.
SELECT(Actor.AllColumns).
DISTINCT().
WHERE(Actor.ActorID.EQ(Int(1)))
WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, query, `
SELECT DISTINCT actor.actor_id AS "actor.actor_id",
@ -24,20 +26,20 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.last_update AS "actor.last_update"
FROM dvds.actor
WHERE actor.actor_id = ?;
`, int64(1))
`, int64(2))
actor := model.Actor{}
err := query.Query(db, &actor)
assert.NilError(t, err)
assert.DeepEqual(t, actor, actor1)
assert.DeepEqual(t, actor, actor2)
}
var actor1 = model.Actor{
ActorID: 1,
FirstName: "PENELOPE",
LastName: "GUINESS",
var actor2 = model.Actor{
ActorID: 2,
FirstName: "NICK",
LastName: "WAHLBERG",
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2),
}
@ -61,7 +63,7 @@ ORDER BY actor.actor_id;
assert.NilError(t, err)
assert.Equal(t, len(dest), 200)
assert.DeepEqual(t, dest[0], actor1)
assert.DeepEqual(t, dest[1], actor2)
//testutils.PrintJson(dest)
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
@ -527,3 +529,172 @@ LOCK IN SHARE MODE;
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestWindowFunction(t *testing.T) {
if sourceIsMariaDB() {
return
}
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (PARTITION BY payment.customer_id),
MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC),
MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING),
MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
ROW_NUMBER() OVER (ORDER BY payment.payment_date),
RANK() OVER (ORDER BY payment.payment_date),
DENSE_RANK() OVER (ORDER BY payment.payment_date),
CUME_DIST() OVER (ORDER BY payment.payment_date),
NTILE(11) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date),
FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date)
FROM dvds.payment
WHERE payment.payment_id < ?
GROUP BY payment.amount, payment.customer_id, payment.payment_date;
`
query := Payment.
SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)),
MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())),
MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))),
MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)),
RANK().OVER(ORDER_BY(Payment.PaymentDate)),
DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)),
CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)),
NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)),
).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate).
WHERE(Payment.PaymentID.LT(Int(10)))
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestWindowClause(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (w1),
AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
FROM dvds.payment
WHERE payment.payment_id < ?
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id;
`
query := Payment.SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER(
Window("w2").
ORDER_BY(Payment.CustomerID).
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
).
WHERE(Payment.PaymentID.LT(Int(10))).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")).
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(Payment.CustomerID)
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestSimpleView(t *testing.T) {
query := SELECT(
view.ActorInfo.AllColumns,
).
FROM(view.ActorInfo).
ORDER_BY(view.ActorInfo.ActorID).
LIMIT(10)
type ActorInfo struct {
ActorID int
FirstName string
LastName string
FilmInfo string
}
var dest []ActorInfo
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 10)
testutils.AssertJSON(t, dest[1:2], `
[
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"FilmInfo": "Action: BULL SHAWSHANK; Animation: FIGHT JAWBREAKER; Children: JERSEY SASSY; Classics: DRACULA CRYSTAL, GILBERT PELICAN; Comedy: MALLRATS UNITED, RUSHMORE MERMAID; Documentary: ADAPTATION HOLES; Drama: WARDROBE PHANTOM; Family: APACHE DIVINE, CHISUM BEHAVIOR, INDIAN LOVE, MAGUIRE APACHE; Foreign: BABY HALL, HAPPINESS UNITED; Games: ROOF CHAMPION; Music: LUCKY FLYING; New: DESTINY SATURDAY, FLASH WARS, JEKYLL FROGMEN, MASK PEACH; Sci-Fi: CHAINSAW UPTOWN, GOODFELLAS SALUTE; Travel: LIAISONS SWEET, SMILE EARRING"
}
]
`)
}
func TestJoinViewWithTable(t *testing.T) {
query := SELECT(
view.CustomerList.AllColumns,
Rental.AllColumns,
).
FROM(view.CustomerList.
INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)),
).
ORDER_BY(view.CustomerList.ID).
WHERE(view.CustomerList.ID.LT_EQ(Int(2)))
var dest []struct {
model.CustomerList `sql:"primary_key=ID"`
Rentals []model.Rental
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27)
}

View file

@ -6,6 +6,7 @@ import (
. "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid"
"gotest.tools/assert"
@ -23,6 +24,19 @@ func TestAllTypesSelect(t *testing.T) {
assert.DeepEqual(t, dest[1], allTypesRow1)
}
func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes
dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest[0], AllTypesView(allTypesRow0))
assert.DeepEqual(t, dest[1], AllTypesView(allTypesRow1))
}
func TestAllTypesInsertModel(t *testing.T) {
query := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(allTypesRow0).
@ -31,8 +45,8 @@ func TestAllTypesInsertModel(t *testing.T) {
dest := []model.AllTypes{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0], allTypesRow0)
assert.DeepEqual(t, dest[1], allTypesRow1)

View file

@ -1,8 +1,8 @@
package postgres
import (
"bytes"
"github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig"
"gotest.tools/assert"
"io/ioutil"
@ -71,27 +71,26 @@ func TestCmdGenerator(t *testing.T) {
func TestGenerator(t *testing.T) {
for i := 0; i < 3; i++ {
err := postgres.Generate(genTestDir2, postgres.DBConnection{
Host: dbconfig.Host,
Port: dbconfig.Port,
User: dbconfig.User,
Password: dbconfig.Password,
SslMode: "disable",
Params: "",
DBName: dbconfig.DBName,
SchemaName: "dvds",
})
assert.NilError(t, err)
assertGeneratedFiles(t)
}
err := os.RemoveAll(genTestDir2)
assert.NilError(t, err)
err = postgres.Generate(genTestDir2, postgres.DBConnection{
Host: dbconfig.Host,
Port: dbconfig.Port,
User: dbconfig.User,
Password: dbconfig.Password,
SslMode: "disable",
Params: "",
DBName: dbconfig.DBName,
SchemaName: "dvds",
})
assert.NilError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genTestDir2)
assert.NilError(t, err)
}
func assertGeneratedFiles(t *testing.T) {
@ -99,53 +98,39 @@ func assertGeneratedFiles(t *testing.T) {
tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table")
assert.NilError(t, err)
assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view")
assert.NilError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile)
// Enums SQL Builder files
enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum")
assert.NilError(t, err)
assertFileNameEqual(t, enumFiles, "mpaa_rating.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile)
testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go")
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files
modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model")
assert.NilError(t, err)
assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go")
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go",
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile)
}
func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile)
}
var mpaaRatingEnumFile = `
@ -236,3 +221,57 @@ type Actor struct {
LastUpdate time.Time
}
`
var actorInfoSQLBuilderFile = `
package view
import (
"github.com/go-jet/jet/postgres"
)
var ActorInfo = newActorInfoTable()
type ActorInfoTable struct {
postgres.Table
//Columns
ActorID postgres.ColumnInteger
FirstName postgres.ColumnString
LastName postgres.ColumnString
FilmInfo postgres.ColumnString
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorInfoTable() *ActorInfoTable {
var (
ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name")
FilmInfoColumn = postgres.StringColumn("film_info")
)
return &ActorInfoTable{
Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
FilmInfo: FilmInfoColumn,
AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
MutableColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
}
}
`

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view"
"gotest.tools/assert"
"testing"
"time"
@ -19,15 +20,15 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM dvds.actor
WHERE actor.actor_id = 1;
WHERE actor.actor_id = 2;
`
query := Actor.
SELECT(Actor.AllColumns).
DISTINCT().
WHERE(Actor.ActorID.EQ(Int(1)))
WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1))
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2))
actor := model.Actor{}
err := query.Query(db, &actor)
@ -35,9 +36,9 @@ WHERE actor.actor_id = 1;
assert.NilError(t, err)
expectedActor := model.Actor{
ActorID: 1,
FirstName: "Penelope",
LastName: "Guiness",
ActorID: 2,
FirstName: "Nick",
LastName: "Wahlberg",
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2),
}
@ -1615,3 +1616,169 @@ SELECT true,
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestWindowFunction(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (PARTITION BY payment.customer_id),
MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC),
MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING),
MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
ROW_NUMBER() OVER (ORDER BY payment.payment_date),
RANK() OVER (ORDER BY payment.payment_date),
DENSE_RANK() OVER (ORDER BY payment.payment_date),
CUME_DIST() OVER (ORDER BY payment.payment_date),
NTILE(11) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, $1) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, $2) OVER (ORDER BY payment.payment_date),
FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date)
FROM dvds.payment
WHERE payment.payment_id < $3
GROUP BY payment.amount, payment.customer_id, payment.payment_date;
`
query := Payment.
SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)),
MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())),
MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))),
MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)),
RANK().OVER(ORDER_BY(Payment.PaymentDate)),
DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)),
CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)),
NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)),
).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate).
WHERE(Payment.PaymentID.LT(Int(10)))
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestWindowClause(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (w1),
AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
FROM dvds.payment
WHERE payment.payment_id < $1
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id;
`
query := Payment.SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER(
Window("w2").
ORDER_BY(Payment.CustomerID).
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
).
WHERE(Payment.PaymentID.LT(Int(10))).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")).
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(Payment.CustomerID)
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestSimpleView(t *testing.T) {
query := SELECT(
view.ActorInfo.AllColumns,
).
FROM(view.ActorInfo).
ORDER_BY(view.ActorInfo.ActorID).
LIMIT(10)
type ActorInfo struct {
ActorID int
FirstName string
LastName string
FilmInfo string
}
var dest []ActorInfo
err := query.Query(db, &dest)
assert.NilError(t, err)
testutils.AssertJSON(t, dest[1:2], `
[
{
"ActorID": 2,
"FirstName": "Nick",
"LastName": "Wahlberg",
"FilmInfo": "Action: Bull Shawshank, Animation: Fight Jawbreaker, Children: Jersey Sassy, Classics: Dracula Crystal, Gilbert Pelican, Comedy: Mallrats United, Rushmore Mermaid, Documentary: Adaptation Holes, Drama: Wardrobe Phantom, Family: Apache Divine, Chisum Behavior, Indian Love, Maguire Apache, Foreign: Baby Hall, Happiness United, Games: Roof Champion, Music: Lucky Flying, New: Destiny Saturday, Flash Wars, Jekyll Frogmen, Mask Peach, Sci-Fi: Chainsaw Uptown, Goodfellas Salute, Travel: Liaisons Sweet, Smile Earring"
}
]
`)
}
func TestJoinViewWithTable(t *testing.T) {
query := SELECT(
view.CustomerList.AllColumns,
Rental.AllColumns,
).
FROM(view.CustomerList.
INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)),
).
ORDER_BY(view.CustomerList.ID).
WHERE(view.CustomerList.ID.LT_EQ(Int(2)))
var dest []struct {
model.CustomerList `sql:"primary_key=ID"`
Rentals []model.Rental
}
fmt.Println(query.DebugSql())
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27)
}

@ -1 +1 @@
Subproject commit 7f3f3cc26ce34324f3699d6b422376671b827490
Subproject commit 1f6bd8bb86458019fa43b1e2cd7ae9488a7ac9a4