Merge pull request #13 from go-jet/develop

Merge develop to master
This commit is contained in:
go-jet 2019-09-21 15:58:30 +02:00 committed by GitHub
commit fbf3b6d51c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
48 changed files with 1903 additions and 512 deletions

4
NOTICE
View file

@ -8,3 +8,7 @@ https://github.com/dropbox/godropbox/tree/master/database/sqlbuilder (BSD-3)
This product contains a modified portion of 'snaker' which can be obtained at: This product contains a modified portion of 'snaker' which can be obtained at:
https://github.com/serenize/snaker (MIT) https://github.com/serenize/snaker (MIT)
This product contains `FormatTimestamp` function from 'pq' which can be obtained at:
https://github.com/lib/pq (MIT)

View file

@ -12,7 +12,7 @@ convert database query result into desired arbitrary object structure.
Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases. Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases.
![jet](https://github.com/go-jet/jet/wiki/image/jet.png) ![jet](https://github.com/go-jet/jet/wiki/image/jet.png)
Jet is the easiest and fastest way to write complex SQL queries and map database query result Jet is the easiest and the fastest way to write complex SQL queries and map database query result
into complex object composition. __It is not an ORM.__ into complex object composition. __It is not an ORM.__
## Motivation ## Motivation
@ -46,7 +46,7 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
* UPDATE `(SET, WHERE)`, * UPDATE `(SET, WHERE)`,
* DELETE `(WHERE, ORDER_BY, LIMIT)`, * DELETE `(WHERE, ORDER_BY, LIMIT)`,
* LOCK `(READ, WRITE)` * LOCK `(READ, WRITE)`
2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store
result of database queries. Can be combined to create desired query result destination. result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure. 3) Query execution with result mapping to arbitrary destination structure.
@ -88,12 +88,13 @@ jet -source=PostgreSQL -host=localhost -port=5432 -user=jetuser -password=jetpas
```sh ```sh
Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable
Retrieving schema information... Retrieving schema information...
FOUND 15 table(s), 1 enum(s) FOUND 15 table(s), 7 view(s), 1 enum(s)
Destination directory: ./gen/jetdb/dvds Cleaning up destination directory...
Cleaning up schema destination directory...
Generating table sql builder files... Generating table sql builder files...
Generating table model files... Generating view sql builder files...
Generating enum sql builder files... Generating enum sql builder files...
Generating table model files...
Generating view model files...
Generating enum model files... Generating enum model files...
Done Done
``` ```
@ -102,9 +103,9 @@ be omitted (both databases doesn't have schema support).
_*User has to have a permission to read information schema tables._ _*User has to have a permission to read information schema tables._
As command output suggest, Jet will: As command output suggest, Jet will:
- connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema - connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema
- delete everything in schema destination folder - `./gen/jetdb/dvds`, - delete everything in schema destination folder - `./gen/jetdb/dvds`,
- and finally generate SQL Builder and Model files for each schema table and enum. - and finally generate SQL Builder and Model files for each schema table, view and enum.
Generated files folder structure will look like this: Generated files folder structure will look like this:
@ -112,20 +113,24 @@ Generated files folder structure will look like this:
|-- gen # -path |-- gen # -path
| `-- jetdb # database name | `-- jetdb # database name
| `-- dvds # schema name | `-- dvds # schema name
| |-- enum # sql builder folder for enums | |-- enum # sql builder package for enums
| | |-- mpaa_rating.go | | |-- mpaa_rating.go
| |-- table # sql builder folder for tables | |-- table # sql builder package for tables
| |-- actor.go | |-- actor.go
| |-- address.go | |-- address.go
| |-- category.go | |-- category.go
| ... | ...
| |-- model # model files for each table and enum | |-- view # sql builder package for views
| |-- actor_info.go
| |-- film_list.go
| ...
| |-- model # data model types for each table, view and enum
| | |-- actor.go | | |-- actor.go
| | |-- address.go | | |-- address.go
| | |-- mpaa_rating.go | | |-- mpaa_rating.go
| | ... | | ...
``` ```
Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store
results of the SQL queries. results of the SQL queries.
@ -167,7 +172,8 @@ stmt := SELECT(
Film.FilmID.ASC(), Film.FilmID.ASC(),
) )
``` ```
Package(dot) import is used so that statement would resemble as much as possible as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with _Package(dot) import is used so that statement would resemble as much as possible as native SQL._
Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns
and can be compared only with integer columns and expressions. and can be compared only with integer columns and expressions.
@ -268,11 +274,12 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
#### Execute query and store result #### Execute query and store result
Well formed SQL is just a first half the job. Lets see how can we make some sense of result set returned executing Well formed SQL is just a first half of the job. Lets see how can we make some sense of result set returned executing
above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest. above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest.
First we have to create desired structure to store query result set. First we have to create desired structure to store query result.
This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information). This is done be combining autogenerated model types or it can be done
manually(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)) for more information).
Let's say this is our desired structure: Let's say this is our desired structure:
```go ```go
@ -287,8 +294,8 @@ var dest []struct {
} }
} }
``` ```
Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language `Films` field is a slice because one actor can act in multiple films, and because each film belongs to one language
`Langauge` field is just a single model struct. `Langauge` field is just a single model struct. `Film` can belong to multiple categories.
_*There is no limitation of how big or nested destination can be._ _*There is no limitation of how big or nested destination can be._
Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`. Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`.
@ -504,12 +511,14 @@ The biggest benefit is speed. Speed is improved in 3 major areas:
##### Speed of development ##### Speed of development
Writing SQL queries is much easier, because programmer has the help of SQL code completion and SQL type safety directly in Go. Writing SQL queries is faster and easier, because the developers have help of SQL code completion and SQL type safety directly from Go.
Writing code is much faster and code is more robust. Automatic scan to arbitrary structure removes a lot of headache and Automatic scan to arbitrary structure removes a lot of headache and boilerplate code needed to structure database query result.
boilerplate code needed to structure database query result.
##### Speed of execution ##### Speed of execution
While ORM libraries can introduce significant performance penalties due to number of round-trips to the database,
Jet will always perform much better, because of the single database call.
Common web and database server usually are not on the same physical machine, and there is some latency between them. Common web and database server usually are not on the same physical machine, and there is some latency between them.
Latency can vary from 5ms to 50+ms. In majority of cases query executed on database is simple query lasting no more than 1ms. Latency can vary from 5ms to 50+ms. In majority of cases query executed on database is simple query lasting no more than 1ms.
In those cases web server handler execution time is directly proportional to latency between server and database. In those cases web server handler execution time is directly proportional to latency between server and database.
@ -521,14 +530,14 @@ With Jet, handler time lost on latency between server and database is constant.
return result in one database call. Handler execution will be only proportional to the number of rows returned from database. return result in one database call. Handler execution will be only proportional to the number of rows returned from database.
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate). ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call. With Jet you can even join the whole database and store the whole structured result in one database call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40). This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40).
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found ##### How quickly bugs are found
The most expensive bugs are the one on the production and the least expensive are those found during development. The most expensive bugs are the one on the production and the least expensive are those found during development.
With automatically generated type safe SQL not only queries are written faster but bugs are found sooner. With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner.
Lets return to quick start example, and take closer look at a line: Lets return to quick start example, and take closer look at a line:
```go ```go
AND(Film.Length.GT(Int(180))), AND(Film.Length.GT(Int(180))),

View file

@ -91,9 +91,7 @@ func jsonSave(path string, v interface{}) {
err := ioutil.WriteFile(path, jsonText, 0644) err := ioutil.WriteFile(path, jsonText, 0644)
if err != nil { panicOnError(err)
panic(err)
}
} }
func printStatementInfo(stmt SelectStatement) { func printStatementInfo(stmt SelectStatement) {

View file

@ -771,7 +771,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType) ret.subTypes = append(ret.subTypes, subType)
} }
} else if isPrimaryKey(field) { } else if isPrimaryKey(field, parentField) {
index := s.typeToColumnIndex(newTypeName, fieldName) index := s.typeToColumnIndex(newTypeName, fieldName)
if index < 0 { if index < 0 {
@ -813,9 +813,7 @@ func (s *scanContext) rowElem(index int) interface{} {
value, err := valuer.Value() value, err := valuer.Value()
if err != nil { utils.PanicOnError(err)
panic(err)
}
return value return value
} }
@ -837,13 +835,45 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
return newElem return newElem
} }
func isPrimaryKey(field reflect.StructField) bool { func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool {
if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite {
return isPrimaryKey
}
sqlTag := field.Tag.Get("sql") sqlTag := field.Tag.Get("sql")
return sqlTag == "primary_key" return sqlTag == "primary_key"
} }
func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) {
if parentField == nil {
return
}
sqlTag := parentField.Tag.Get("sql")
if !strings.HasPrefix(sqlTag, "primary_key") {
return
}
parts := strings.Split(sqlTag, "=")
if len(parts) < 2 {
return
}
primaryKeyColumns := strings.Split(parts[1], ",")
for _, primaryKeyCol := range primaryKeyColumns {
if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) {
return true, true
}
}
return true, false
}
func indirectType(reflectType reflect.Type) reflect.Type { func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr { if reflectType.Kind() != reflect.Ptr {
return reflectType return reflectType

View file

@ -62,7 +62,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = parseTime(v) nt.Time, nt.Valid = parseTime(v)
return return
default: default:
return fmt.Errorf("can't scan time from %v", value) return fmt.Errorf("can't scan time.Time from %v", value)
} }
} }

View file

@ -0,0 +1,147 @@
package internal
import (
"fmt"
"gotest.tools/assert"
"testing"
"time"
)
func TestNullByteArray(t *testing.T) {
var array NullByteArray
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan([]byte("bytea")))
assert.Equal(t, array.Valid, true)
assert.Equal(t, string(array.ByteArray), string([]byte("bytea")))
assert.Error(t, array.Scan(12), "can't scan []byte from 12")
}
func TestNullTime(t *testing.T) {
var array NullTime
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
time := time.Now()
assert.NilError(t, array.Scan(time))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, time)
assert.NilError(t, array.Scan([]byte("13:10:11")))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.NilError(t, array.Scan("13:10:11"))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.Error(t, array.Scan(12), "can't scan time.Time from 12")
}
func TestNullInt8(t *testing.T) {
var array NullInt8
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int8(11))
assert.Error(t, array.Scan("text"), "can't scan int8 from text")
}
func TestNullInt16(t *testing.T) {
var array NullInt16
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int16(11))
assert.NilError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(20))
assert.NilError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(30))
assert.NilError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(30))
assert.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int32(11))
assert.NilError(t, array.Scan(int32(32)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(32))
assert.NilError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(20))
assert.NilError(t, array.Scan(uint16(16)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(16))
assert.NilError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(30))
assert.NilError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(30))
assert.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
assert.NilError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(float64(64)))
assert.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, float32(64))
assert.NilError(t, array.Scan(float32(32)))
assert.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, float32(32))
assert.Error(t, array.Scan(12), "can't scan float32 from 12")
}

View file

@ -142,13 +142,10 @@ func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
return "" return ""
} }
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) { func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName) rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
defer rows.Close() defer rows.Close()
ret := []ColumnMetaData{} ret := []ColumnMetaData{}
@ -157,19 +154,13 @@ func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableN
var name, isNullable, dataType, enumName string var name, isNullable, dataType, enumName string
var isUnsigned bool var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned) err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned)) ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
} }
err = rows.Err() err = rows.Err()
utils.PanicOnError(err)
if err != nil { return ret
return nil, err
}
return ret, nil
} }

View file

@ -11,5 +11,5 @@ type DialectQuerySet interface {
ListOfColumnsQuery() string ListOfColumnsQuery() string
ListOfEnumsQuery() string ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData
} }

View file

@ -3,41 +3,43 @@ package metadata
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/internal/utils"
) )
// SchemaMetaData struct // SchemaMetaData struct
type SchemaMetaData struct { type SchemaMetaData struct {
TableInfos []MetaData TablesMetaData []MetaData
EnumInfos []MetaData ViewsMetaData []MetaData
EnumsMetaData []MetaData
} }
// GetSchemaInfo returns schema information from db connection. // IsEmpty returns true if schema info does not contain any table, views or enums metadata
func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) { func (s SchemaMetaData) IsEmpty() bool {
return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0
}
schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName) const (
baseTable = "BASE TABLE"
view = "VIEW"
)
if err != nil { // GetSchemaMetaData returns schema information from db connection.
return func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) {
}
schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName) schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable)
schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view)
schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName)
if err != nil { fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),",
return len(schemaInfo.EnumsMetaData), "enum(s)")
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
return return
} }
func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) { func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName) rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
defer rows.Close() defer rows.Close()
ret := []MetaData{} ret := []MetaData{}
@ -45,24 +47,15 @@ func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]M
var tableName string var tableName string
err = rows.Scan(&tableName) err = rows.Scan(&tableName)
if err != nil { utils.PanicOnError(err)
return nil, err
}
tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName) tableInfo := GetTableMetaData(db, querySet, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo) ret = append(ret, tableInfo)
} }
err = rows.Err() err = rows.Err()
utils.PanicOnError(err)
if err != nil { return ret
return nil, err
}
return ret, nil
} }

View file

@ -67,46 +67,32 @@ func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table" return utils.ToGoIdentifier(t.name) + "Table"
} }
// GetTableInfo returns table info metadata // GetTableMetaData returns table info metadata
func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) { func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) {
tableInfo.SchemaName = schemaName tableInfo.SchemaName = schemaName
tableInfo.name = tableName tableInfo.name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName) tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName)
if err != nil { tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName)
return
}
tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName)
if err != nil {
return
}
return return
} }
func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) { func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool {
rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName) rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
primaryKeyMap := map[string]bool{} primaryKeyMap := map[string]bool{}
for rows.Next() { for rows.Next() {
primaryKey := "" primaryKey := ""
err := rows.Scan(&primaryKey) err := rows.Scan(&primaryKey)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
primaryKeyMap[primaryKey] = true primaryKeyMap[primaryKey] = true
} }
return primaryKeyMap, nil return primaryKeyMap
} }

View file

@ -12,89 +12,65 @@ import (
) )
// GenerateFiles generates Go files from tables and enums metadata // GenerateFiles generates Go files from tables and enums metadata
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error { func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) {
if len(tables) == 0 && len(enums) == 0 { if schemaInfo.IsEmpty() {
return nil return
} }
fmt.Println("Destination directory:", destDir) fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...") fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir) err := utils.CleanUpGeneratedFiles(destDir)
utils.PanicOnError(err)
if err != nil { generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect)
return err generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect)
} generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect)
fmt.Println("Generating table sql builder files...") generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect)
err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect) generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect)
generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(destDir, "model", tableModelTemplate, tables, dialect)
if err != nil {
return err
}
if len(enums) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(destDir, "model", enumModelTemplate, enums, dialect)
if err != nil {
return err
}
}
fmt.Println("Done") fmt.Println("Done")
return nil
} }
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error { func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s sql builder files...\n", fileTypes)
generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect)
}
func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s model files...\n", fileTypes)
generateGoFiles(destDir, "model", modelTemplate, metaData, dialect)
}
func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) {
modelDirPath := filepath.Join(dirPath, packageName) modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath) err := utils.EnsureDirPath(modelDirPath)
utils.PanicOnError(err)
if err != nil {
return err
}
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect) autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
utils.PanicOnError(err)
if err != nil {
return err
}
for _, metaData := range metaDataList { for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect) text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName})
utils.PanicOnError(err)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...)) err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
utils.PanicOnError(err)
if err != nil {
return err
}
} }
return nil return
} }
// GenerateTemplate generates template with template text and template data. // GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) { func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier, "ToGoIdentifier": utils.ToGoIdentifier,
@ -102,7 +78,13 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect1 je
return time.Now().Format(time.RFC850) return time.Now().Format(time.RFC850)
}, },
"dialect": func() jet.Dialect { "dialect": func() jet.Dialect {
return dialect1 return dialect
},
"param": func(name string) interface{} {
if len(params) > 0 {
return params[0][name]
}
return ""
}, },
}).Parse(templateText) }).Parse(templateText)

View file

@ -18,7 +18,7 @@ var tableSQLBuilderTemplate = `
{{- end}} {{- end}}
{{- end}} {{- end}}
package table package {{param "package"}}
import ( import (
"github.com/go-jet/jet/{{dialect.PackageName}}" "github.com/go-jet/jet/{{dialect.PackageName}}"

View file

@ -22,50 +22,34 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error { func Generate(destDir string, dbConn DBConnection) (err error) {
db, err := openConnection(dbConn) defer utils.ErrorCatch(&err)
if err != nil {
return err db := openConnection(dbConn)
}
defer utils.DBClose(db) defer utils.DBClose(db)
fmt.Println("Retrieving database information...") fmt.Println("Retrieving database information...")
// No schemas in MySQL // No schemas in MySQL
dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &mySqlQuerySet{}) dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName) genPath := path.Join(destDir, dbConn.DBName)
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect) template.GenerateFiles(genPath, dbInfo, mysql.Dialect)
if err != nil {
return err
}
return nil return nil
} }
func openConnection(dbConn DBConnection) (*sql.DB, error) { func openConnection(dbConn DBConnection) *sql.DB {
var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" { if dbConn.Params != "" {
connectionString += "?" + dbConn.Params connectionString += "?" + dbConn.Params
} }
db, err := sql.Open("mysql", connectionString)
fmt.Println("Connecting to MySQL database: " + connectionString) fmt.Println("Connecting to MySQL database: " + connectionString)
db, err := sql.Open("mysql", connectionString)
if err != nil { utils.PanicOnError(err)
return nil, err
}
err = db.Ping() err = db.Ping()
utils.PanicOnError(err)
if err != nil { return db
return nil, err
}
return db, nil
} }

View file

@ -3,6 +3,7 @@ package mysql
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/utils"
"strings" "strings"
) )
@ -13,7 +14,7 @@ func (m *mySqlQuerySet) ListOfTablesQuery() string {
return ` return `
SELECT table_name SELECT table_name
FROM INFORMATION_SCHEMA.tables FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = 'BASE TABLE'; WHERE table_schema = ? and table_type = ?;
` `
} }
@ -46,17 +47,14 @@ func (m *mySqlQuerySet) ListOfEnumsQuery() string {
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5)
FROM information_schema.columns as c FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE'; WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
` `
} }
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData {
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) rows, err := db.Query(m.ListOfEnumsQuery(), schemaName)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
defer rows.Close() defer rows.Close()
ret := []metadata.MetaData{} ret := []metadata.MetaData{}
@ -65,9 +63,7 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad
var enumName string var enumName string
var enumValues string var enumValues string
err = rows.Scan(&enumName, &enumValues) err = rows.Scan(&enumName, &enumValues)
if err != nil { utils.PanicOnError(err)
return nil, err
}
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1) enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
@ -78,11 +74,8 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad
} }
err = rows.Err() err = rows.Err()
utils.PanicOnError(err)
if err != nil { return ret
return nil, err
}
return ret, nil
} }

View file

@ -25,31 +25,20 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error { func Generate(destDir string, dbConn DBConnection) (err error) {
defer utils.ErrorCatch(&err)
db, err := openConnection(dbConn) db, err := openConnection(dbConn)
utils.PanicOnError(err)
defer utils.DBClose(db) defer utils.DBClose(db)
if err != nil {
return err
}
fmt.Println("Retrieving schema information...") fmt.Println("Retrieving schema information...")
schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &postgresQuerySet{}) schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
template.GenerateFiles(genPath, schemaInfo, postgres.Dialect)
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect) return
if err != nil {
return err
}
return nil
} }
func openConnection(dbConn DBConnection) (*sql.DB, error) { func openConnection(dbConn DBConnection) (*sql.DB, error) {

View file

@ -3,6 +3,7 @@ package postgres
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/utils"
) )
// postgresQuerySet is dialect query set for PostgreSQL // postgresQuerySet is dialect query set for PostgreSQL
@ -12,7 +13,7 @@ func (p *postgresQuerySet) ListOfTablesQuery() string {
return ` return `
SELECT table_name SELECT table_name
FROM information_schema.tables FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE'; where table_schema = $1 and table_type = $2;
` `
} }
@ -45,12 +46,9 @@ WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;` ORDER BY n.nspname, t.typname, e.enumsortorder;`
} }
func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData {
rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) rows, err := db.Query(p.ListOfEnumsQuery(), schemaName)
utils.PanicOnError(err)
if err != nil {
return nil, err
}
defer rows.Close() defer rows.Close()
enumsInfosMap := map[string][]string{} enumsInfosMap := map[string][]string{}
@ -58,9 +56,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me
var enumName string var enumName string
var enumValue string var enumValue string
err = rows.Scan(&enumName, &enumValue) err = rows.Scan(&enumName, &enumValue)
if err != nil { utils.PanicOnError(err)
return nil, err
}
enumValues := enumsInfosMap[enumName] enumValues := enumsInfosMap[enumName]
@ -70,10 +66,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me
} }
err = rows.Err() err = rows.Err()
utils.PanicOnError(err)
if err != nil {
return nil, err
}
ret := []metadata.MetaData{} ret := []metadata.MetaData{}
@ -84,5 +77,5 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me
}) })
} }
return ret, nil return ret
} }

View file

@ -0,0 +1,42 @@
package pq
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
import (
"strconv"
"time"
)
// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq
func FormatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}

View file

@ -0,0 +1,39 @@
package pq
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
import (
"testing"
"time"
)
var formatTimeTests = []struct {
time time.Time
expected string
}{
{time.Time{}, "0001-01-01 00:00:00Z"},
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"},
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"},
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"},
{time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"},
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"},
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"},
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"},
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"},
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"},
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"},
{time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"},
{time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"},
}
func TestFormatTs(t *testing.T) {
for i, tt := range formatTimeTests {
val := string(FormatTimestamp(tt.time))
if val != tt.expected {
t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected)
}
}
}

View file

@ -135,6 +135,7 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
// ClauseOrderBy struct // ClauseOrderBy struct
type ClauseOrderBy struct { type ClauseOrderBy struct {
List []OrderByClause List []OrderByClause
SkipNewLine bool
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
@ -143,7 +144,9 @@ func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder)
return return
} }
if !o.SkipNewLine {
out.NewLine() out.NewLine()
}
out.WriteString("ORDER BY") out.WriteString("ORDER BY")
out.IncreaseIdent() out.IncreaseIdent()
@ -469,3 +472,37 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString(string(i.LockMode)) out.WriteString(string(i.LockMode))
out.WriteString("MODE") out.WriteString("MODE")
} }
// WindowDefinition struct
type WindowDefinition struct {
Name string
Window Window
}
// ClauseWindow struct
type ClauseWindow struct {
Definitions []WindowDefinition
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) {
if len(i.Definitions) == 0 {
return
}
out.NewLine()
out.WriteString("WINDOW")
for i, def := range i.Definitions {
if i > 0 {
out.WriteString(", ")
}
out.WriteString(def.Name)
out.WriteString("AS")
if def.Window == nil {
out.WriteString("()")
continue
}
def.Window.serialize(statementType, out)
}
}

View file

@ -81,68 +81,154 @@ func LOG(floatExpression FloatExpression) FloatExpression {
// ----------------- Aggregate functions -------------------// // ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression // AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression { func AVG(numericExpression NumericExpression) floatWindowExpression {
return NewFloatFunc("AVG", numericExpression) return NewFloatWindowFunc("AVG", numericExpression)
} }
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. // BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
func BIT_AND(integerExpression IntegerExpression) IntegerExpression { func BIT_AND(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerFunc("BIT_AND", integerExpression) return newIntegerWindowFunc("BIT_AND", integerExpression)
} }
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none. // BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
func BIT_OR(integerExpression IntegerExpression) IntegerExpression { func BIT_OR(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerFunc("BIT_OR", integerExpression) return newIntegerWindowFunc("BIT_OR", integerExpression)
} }
// BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false // BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false
func BOOL_AND(boolExpression BoolExpression) BoolExpression { func BOOL_AND(boolExpression BoolExpression) boolWindowExpression {
return newBoolFunc("BOOL_AND", boolExpression) return newBoolWindowFunc("BOOL_AND", boolExpression)
} }
// BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false // BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false
func BOOL_OR(boolExpression BoolExpression) BoolExpression { func BOOL_OR(boolExpression BoolExpression) boolWindowExpression {
return newBoolFunc("BOOL_OR", boolExpression) return newBoolWindowFunc("BOOL_OR", boolExpression)
} }
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null. // COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
func COUNT(expression Expression) IntegerExpression { func COUNT(expression Expression) integerWindowExpression {
return newIntegerFunc("COUNT", expression) return newIntegerWindowFunc("COUNT", expression)
} }
// EVERY is aggregate function. Returns true if all input values are true, otherwise false // EVERY is aggregate function. Returns true if all input values are true, otherwise false
func EVERY(boolExpression BoolExpression) BoolExpression { func EVERY(boolExpression BoolExpression) boolWindowExpression {
return newBoolFunc("EVERY", boolExpression) return newBoolWindowFunc("EVERY", boolExpression)
} }
// MAXf is aggregate function. Returns maximum value of float expression across all input values // MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression { func MAXf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatFunc("MAX", floatExpression) return NewFloatWindowFunc("MAX", floatExpression)
} }
// MAXi is aggregate function. Returns maximum value of int expression across all input values // MAXi is aggregate function. Returns maximum value of int expression across all input values
func MAXi(integerExpression IntegerExpression) IntegerExpression { func MAXi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerFunc("MAX", integerExpression) return newIntegerWindowFunc("MAX", integerExpression)
} }
// MINf is aggregate function. Returns minimum value of float expression across all input values // MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression { func MINf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatFunc("MIN", floatExpression) return NewFloatWindowFunc("MIN", floatExpression)
} }
// MINi is aggregate function. Returns minimum value of int expression across all input values // MINi is aggregate function. Returns minimum value of int expression across all input values
func MINi(integerExpression IntegerExpression) IntegerExpression { func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerFunc("MIN", integerExpression) return newIntegerWindowFunc("MIN", integerExpression)
} }
// SUMf is aggregate function. Returns sum of expression across all float expressions // SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression { func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatFunc("SUM", floatExpression) return NewFloatWindowFunc("SUM", floatExpression)
} }
// SUMi is aggregate function. Returns sum of expression across all integer expression. // SUMi is aggregate function. Returns sum of expression across all integer expression.
func SUMi(integerExpression IntegerExpression) IntegerExpression { func SUMi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerFunc("SUM", integerExpression) return newIntegerWindowFunc("SUM", integerExpression)
}
// ----------------- Window functions -------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
func ROW_NUMBER() integerWindowExpression {
return newIntegerWindowFunc("ROW_NUMBER")
}
// RANK of the current row with gaps; same as row_number of its first peer
func RANK() integerWindowExpression {
return newIntegerWindowFunc("RANK")
}
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
func DENSE_RANK() integerWindowExpression {
return newIntegerWindowFunc("DENSE_RANK")
}
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
func PERCENT_RANK() floatWindowExpression {
return NewFloatWindowFunc("PERCENT_RANK")
}
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
func CUME_DIST() floatWindowExpression {
return NewFloatWindowFunc("CUME_DIST")
}
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
func NTILE(numOfBuckets int64) integerWindowExpression {
return newIntegerWindowFunc("NTILE", FixedLiteral(numOfBuckets))
}
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
func LAG(expr Expression, offsetAndDefault ...interface{}) windowExpression {
return leadLagImpl("LAG", expr, offsetAndDefault...)
}
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
func LEAD(expr Expression, offsetAndDefault ...interface{}) windowExpression {
return leadLagImpl("LEAD", expr, offsetAndDefault...)
}
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
func FIRST_VALUE(value Expression) windowExpression {
return newWindowFunc("FIRST_VALUE", value)
}
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
func LAST_VALUE(value Expression) windowExpression {
return newWindowFunc("LAST_VALUE", value)
}
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
func NTH_VALUE(value Expression, nth int64) windowExpression {
return newWindowFunc("NTH_VALUE", value, FixedLiteral(nth))
}
func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) windowExpression {
params := []Expression{expr}
if len(offsetAndDefault) >= 2 {
offset, ok := offsetAndDefault[0].(int)
if !ok {
panic("jet: LAG offset should be an integer")
}
var defaultValue Expression
defaultValue, ok = offsetAndDefault[1].(Expression)
if !ok {
defaultValue = literal(offsetAndDefault[1])
}
params = append(params, FixedLiteral(offset), defaultValue)
}
return newWindowFunc(name, params...)
} }
//------------ String functions ------------------// //------------ String functions ------------------//
@ -349,7 +435,7 @@ func TO_HEX(number IntegerExpression) StringExpression {
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression { func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression {
if len(matchType) > 0 { if len(matchType) > 0 {
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0])) return newBoolFunc("REGEXP_LIKE", stringExp, pattern, FixedLiteral(matchType[0]))
} }
return newBoolFunc("REGEXP_LIKE", stringExp, pattern) return newBoolFunc("REGEXP_LIKE", stringExp, pattern)
@ -391,7 +477,7 @@ func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc var timezFunc *timezFunc
if len(precision) > 0 { if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0])) timezFunc = newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0]))
} else { } else {
timezFunc = newTimezFunc("CURRENT_TIME") timezFunc = newTimezFunc("CURRENT_TIME")
} }
@ -406,7 +492,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc var timestampzFunc *timestampzFunc
if len(precision) > 0 { if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0])) timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0]))
} else { } else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
} }
@ -421,7 +507,7 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc var timeFunc *timeFunc
if len(precision) > 0 { if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0])) timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0]))
} else { } else {
timeFunc = newTimeFunc("LOCALTIME") timeFunc = newTimeFunc("LOCALTIME")
} }
@ -436,7 +522,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc var timestampFunc *timestampFunc
if len(precision) > 0 { if len(precision) > 0 {
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0])) timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0]))
} else { } else {
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP") timestampFunc = NewTimestampFunc("LOCALTIMESTAMP")
} }
@ -504,6 +590,16 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp return funcExp
} }
// NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := newFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun)
newFun.expressionInterfaceImpl.Parent = windowExpr
return windowExpr
}
func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(f.expressions...) serializeOverrideFunc := serializeOverride(f.expressions...)
@ -536,10 +632,23 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression {
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
boolFunc.boolInterfaceImpl.parent = boolFunc boolFunc.boolInterfaceImpl.parent = boolFunc
boolFunc.expressionInterfaceImpl.Parent = boolFunc
return boolFunc return boolFunc
} }
// NewFloatWindowFunc creates new float function with name and expressions
func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression {
boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
intWindowFunc := newBoolWindowExpression(boolFunc)
boolFunc.boolInterfaceImpl.parent = intWindowFunc
boolFunc.expressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc
}
type floatFunc struct { type floatFunc struct {
funcExpressionImpl funcExpressionImpl
floatInterfaceImpl floatInterfaceImpl
@ -555,6 +664,18 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
return floatFunc return floatFunc
} }
// NewFloatWindowFunc creates new float function with name and expressions
func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression {
floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
floatWindowFunc := newFloatWindowExpression(floatFunc)
floatFunc.floatInterfaceImpl.parent = floatWindowFunc
floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc
return floatWindowFunc
}
type integerFunc struct { type integerFunc struct {
funcExpressionImpl funcExpressionImpl
integerInterfaceImpl integerInterfaceImpl
@ -569,6 +690,18 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
return floatFunc return floatFunc
} }
// NewFloatWindowFunc creates new float function with name and expressions
func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression {
integerFunc := &integerFunc{}
integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc)
intWindowFunc := newIntegerWindowExpression(integerFunc)
integerFunc.integerInterfaceImpl.parent = intWindowFunc
integerFunc.expressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc
}
type stringFunc struct { type stringFunc struct {
funcExpressionImpl funcExpressionImpl
stringInterfaceImpl stringInterfaceImpl

View file

@ -32,8 +32,8 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl
return &exp return &exp
} }
// ConstLiteral is injected directly to SQL query, and does not appear in argument list. // FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list.
func ConstLiteral(value interface{}) *literalExpressionImpl { func FixedLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value) exp := literal(value)
exp.constant = true exp.constant = true

View file

@ -11,16 +11,9 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(true), "TRUE") assert.Equal(t, argToString(true), "TRUE")
assert.Equal(t, argToString(false), "FALSE") assert.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int8(-8)), "-8")
assert.Equal(t, argToString(int16(-16)), "-16")
assert.Equal(t, argToString(int(-32)), "-32") assert.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(int32(-32)), "-32") assert.Equal(t, argToString(int32(-32)), "-32")
assert.Equal(t, argToString(int64(-64)), "-64") assert.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint8(8)), "8")
assert.Equal(t, argToString(uint16(16)), "16")
assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float64(1.11)), "1.11") assert.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'") assert.Equal(t, argToString("john"), "'john'")
@ -31,5 +24,12 @@ func TestArgToString(t *testing.T) {
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
assert.Equal(t, argToString(map[string]bool{}), "[Unsupported type]")
func() {
defer func() {
assert.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
}()
argToString(map[string]bool{})
}()
} }

View file

@ -2,8 +2,11 @@ package jet
import ( import (
"bytes" "bytes"
"fmt"
"github.com/go-jet/jet/internal/3rdparty/pq"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"github.com/google/uuid" "github.com/google/uuid"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -139,28 +142,13 @@ func argToString(value interface{}) string {
return "TRUE" return "TRUE"
} }
return "FALSE" return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int: case int:
return strconv.FormatInt(int64(bindVal), 10) return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32: case int32:
return strconv.FormatInt(int64(bindVal), 10) return strconv.FormatInt(int64(bindVal), 10)
case int64: case int64:
return strconv.FormatInt(bindVal, 10) return strconv.FormatInt(bindVal, 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32: case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64: case float64:
@ -173,9 +161,9 @@ func argToString(value interface{}) string {
case uuid.UUID: case uuid.UUID:
return stringQuote(bindVal.String()) return stringQuote(bindVal.String())
case time.Time: case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal))) return stringQuote(string(pq.FormatTimestamp(bindVal)))
default: default:
return "[Unsupported type]" panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))
} }
} }

View file

@ -19,15 +19,17 @@ type Table interface {
} }
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable { func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable {
columnList := append([]ColumnExpression{column}, columns...)
t := tableImpl{ t := tableImpl{
schemaName: schemaName, schemaName: schemaName,
name: name, name: name,
columnList: columns, columnList: columnList,
} }
for _, c := range columns { for _, c := range columnList {
c.setTableName(name) c.setTableName(name)
} }

View file

@ -0,0 +1,33 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestNewTable(t *testing.T) {
newTable := NewTable("schema", "table", IntegerColumn("intCol"))
assert.Equal(t, newTable.SchemaName(), "schema")
assert.Equal(t, newTable.TableName(), "table")
assert.Equal(t, len(newTable.columns()), 1)
assert.Equal(t, newTable.columns()[0].Name(), "intCol")
}
func TestNewJoinTable(t *testing.T) {
newTable1 := NewTable("schema", "table", IntegerColumn("intCol1"))
newTable2 := NewTable("schema", "table2", IntegerColumn("intCol2"))
joinTable := NewJoinTable(newTable1, newTable2, InnerJoin, IntegerColumn("intCol1").EQ(IntegerColumn("intCol2")))
assertClauseSerialize(t, joinTable, `schema.table
INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`)
assert.Equal(t, joinTable.SchemaName(), "schema")
assert.Equal(t, joinTable.TableName(), "")
assert.Equal(t, len(joinTable.columns()), 2)
assert.Equal(t, joinTable.columns()[0].Name(), "intCol1")
assert.Equal(t, joinTable.columns()[1].Name(), "intCol2")
}

View file

@ -0,0 +1,146 @@
package jet
type commonWindowImpl struct {
expression Expression
window Window
}
func (w *commonWindowImpl) over(window ...Window) {
if len(window) > 0 {
w.window = window[0]
} else {
w.window = newWindowImpl(nil)
}
}
func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
w.expression.serialize(statement, out)
if w.window != nil {
out.WriteString("OVER")
w.window.serialize(statement, out)
}
}
// --------------------------------------
type windowExpression interface {
Expression
OVER(window ...Window) Expression
}
func newWindowExpression(Exp Expression) windowExpression {
newExp := &windowExpressionImpl{
Expression: Exp,
}
newExp.commonWindowImpl.expression = Exp
return newExp
}
type windowExpressionImpl struct {
Expression
commonWindowImpl
}
func (f *windowExpressionImpl) OVER(window ...Window) Expression {
f.commonWindowImpl.over(window...)
return f
}
func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// -----------------------------------------------------
type floatWindowExpression interface {
FloatExpression
OVER(window ...Window) FloatExpression
}
func newFloatWindowExpression(floatExp FloatExpression) floatWindowExpression {
newExp := &floatWindowExpressionImpl{
FloatExpression: floatExp,
}
newExp.commonWindowImpl.expression = floatExp
return newExp
}
type floatWindowExpressionImpl struct {
FloatExpression
commonWindowImpl
}
func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// ------------------------------------------------
type integerWindowExpression interface {
IntegerExpression
OVER(window ...Window) IntegerExpression
}
func newIntegerWindowExpression(intExp IntegerExpression) integerWindowExpression {
newExp := &integerWindowExpressionImpl{
IntegerExpression: intExp,
}
newExp.commonWindowImpl.expression = intExp
return newExp
}
type integerWindowExpressionImpl struct {
IntegerExpression
commonWindowImpl
}
func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}
// ------------------------------------------------
type boolWindowExpression interface {
BoolExpression
OVER(window ...Window) BoolExpression
}
func newBoolWindowExpression(boolExp BoolExpression) boolWindowExpression {
newExp := &boolWindowExpressionImpl{
BoolExpression: boolExp,
}
newExp.commonWindowImpl.expression = boolExp
return newExp
}
type boolWindowExpressionImpl struct {
BoolExpression
commonWindowImpl
}
func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression {
f.commonWindowImpl.over(window...)
return f
}
func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
}

186
internal/jet/window_func.go Normal file
View file

@ -0,0 +1,186 @@
package jet
// Window interface
type Window interface {
Serializer
ORDER_BY(expr ...OrderByClause) Window
ROWS(start FrameExtent, end ...FrameExtent) Window
RANGE(start FrameExtent, end ...FrameExtent) Window
GROUPS(start FrameExtent, end ...FrameExtent) Window
}
type windowImpl struct {
partitionBy []Expression
orderBy ClauseOrderBy
frameUnits string
start, end FrameExtent
parent Window
}
func newWindowImpl(parent Window) *windowImpl {
newWindow := &windowImpl{}
if parent == nil {
newWindow.parent = newWindow
} else {
newWindow.parent = parent
}
return newWindow
}
func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
out.WriteByte('(')
}
if w.partitionBy != nil {
out.WriteString("PARTITION BY")
serializeExpressionList(statement, w.partitionBy, ", ", out)
}
w.orderBy.SkipNewLine = true
w.orderBy.Serialize(statement, out)
if w.frameUnits != "" {
out.WriteString(w.frameUnits)
if w.end == nil {
w.start.serialize(statement, out)
} else {
out.WriteString("BETWEEN")
w.start.serialize(statement, out)
out.WriteString("AND")
w.end.serialize(statement, out)
}
}
if !contains(options, noWrap) {
out.WriteByte(')')
}
}
func (w *windowImpl) ORDER_BY(exprs ...OrderByClause) Window {
w.orderBy.List = exprs
return w.parent
}
func (w *windowImpl) ROWS(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "ROWS"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) RANGE(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "RANGE"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) GROUPS(start FrameExtent, end ...FrameExtent) Window {
w.frameUnits = "GROUPS"
w.setFrameRange(start, end...)
return w.parent
}
func (w *windowImpl) setFrameRange(start FrameExtent, end ...FrameExtent) {
w.start = start
if len(end) > 0 {
w.end = end[0]
}
}
// PARTITION_BY window function constructor
func PARTITION_BY(exp Expression, exprs ...Expression) Window {
funImpl := newWindowImpl(nil)
funImpl.partitionBy = append([]Expression{exp}, exprs...)
return funImpl
}
// ORDER_BY window function constructor
func ORDER_BY(expr ...OrderByClause) Window {
funImpl := newWindowImpl(nil)
funImpl.orderBy.List = expr
return funImpl
}
// -----------------------------------------------
// FrameExtent interface
type FrameExtent interface {
Serializer
isFrameExtent()
}
// PRECEDING window frame clause
func PRECEDING(offset Serializer) FrameExtent {
return &frameExtentImpl{
preceding: true,
offset: offset,
}
}
// FOLLOWING window frame clause
func FOLLOWING(offset Serializer) FrameExtent {
return &frameExtentImpl{
preceding: false,
offset: offset,
}
}
type frameExtentImpl struct {
preceding bool
offset Serializer
}
func (f *frameExtentImpl) isFrameExtent() {}
func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if f == nil {
return
}
f.offset.serialize(statement, out)
if f.preceding {
out.WriteString("PRECEDING")
} else {
out.WriteString("FOLLOWING")
}
}
// -----------------------------------------------
// Window function keywords
var (
UNBOUNDED = keywordClause("UNBOUNDED")
CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"}
)
type frameExtentKeyword struct {
keywordClause
}
func (f frameExtentKeyword) isFrameExtent() {}
// -----------------------------------------------
// WindowName is used to specify window reference from WINDOW clause
func WindowName(name string) Window {
newWindow := &windowName{name: name}
newWindow.parent = newWindow
return newWindow
}
type windowName struct {
windowImpl
name string
}
func (w windowName) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.WriteString(w.name)
w.windowImpl.serialize(statement, out, noWrap)
out.WriteByte(')')
}

View file

@ -0,0 +1,21 @@
package jet
import "testing"
func TestFrameExtent(t *testing.T) {
assertClauseSerialize(t, PRECEDING(Int(2)), "$1 PRECEDING", int64(2))
assertClauseSerialize(t, FOLLOWING(Int(4)), "$1 FOLLOWING", int64(4))
}
func TestWindowFunctions(t *testing.T) {
assertClauseSerialize(t, PARTITION_BY(table1Col1), "(PARTITION BY table1.col1)")
assertClauseSerialize(t, PARTITION_BY(table1Col3).ORDER_BY(table1Col1), "(PARTITION BY table1.col3 ORDER BY table1.col1)")
assertClauseSerialize(t, ORDER_BY(table1Col1), "(ORDER BY table1.col1)")
assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1))), "(ORDER BY table1.col1 ROWS $1 PRECEDING)", int64(1))
assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1)), FOLLOWING(Int(33))),
"(ORDER BY table1.col1 ROWS BETWEEN $1 PRECEDING AND $2 FOLLOWING)", int64(1), int64(33))
assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
"(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)")
assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), CURRENT_ROW),
"(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)")
}

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"gotest.tools/assert" "gotest.tools/assert"
"io/ioutil" "io/ioutil"
"os" "os"
@ -60,9 +61,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath) filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644) err := ioutil.WriteFile(filePath, jsonText, 0644)
if err != nil { utils.PanicOnError(err)
panic(err)
}
} }
// AssertJSONFile check if data json representation is the same as json at testRelativePath // AssertJSONFile check if data json representation is the same as json at testRelativePath
@ -159,3 +158,31 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest
stmt.Query(db, dest) stmt.Query(db, dest)
} }
// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
assert.Equal(t, len(fileInfos), len(fileNames))
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
}

View file

@ -1,6 +1,7 @@
package testutils package testutils
import ( import (
"github.com/go-jet/jet/internal/utils"
"strings" "strings"
"time" "time"
) )
@ -9,9 +10,7 @@ import (
func Date(t string) *time.Time { func Date(t string) *time.Time {
newTime, err := time.Parse("2006-01-02", t) newTime, err := time.Parse("2006-01-02", t)
if err != nil { utils.PanicOnError(err)
panic(err)
}
return &newTime return &newTime
} }
@ -27,9 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000")
if err != nil { utils.PanicOnError(err)
panic(err)
}
return &newTime return &newTime
} }
@ -38,9 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
func TimeWithoutTimeZone(t string) *time.Time { func TimeWithoutTimeZone(t string) *time.Time {
newTime, err := time.Parse("15:04:05", t) newTime, err := time.Parse("15:04:05", t)
if err != nil { utils.PanicOnError(err)
panic(err)
}
return &newTime return &newTime
} }
@ -49,9 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time {
func TimeWithTimeZone(t string) *time.Time { func TimeWithTimeZone(t string) *time.Time {
newTimez, err := time.Parse("15:04:05 -0700", t) newTimez, err := time.Parse("15:04:05 -0700", t)
if err != nil { utils.PanicOnError(err)
panic(err)
}
return &newTimez return &newTimez
} }
@ -67,9 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
if err != nil { utils.PanicOnError(err)
panic(err)
}
return &newTime return &newTime
} }

View file

@ -2,14 +2,13 @@ package utils
import ( import (
"database/sql" "database/sql"
"fmt"
"github.com/go-jet/jet/internal/3rdparty/snaker" "github.com/go-jet/jet/internal/3rdparty/snaker"
"go/format" "go/format"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time"
) )
// ToGoIdentifier converts database to Go identifier. // ToGoIdentifier converts database to Go identifier.
@ -104,44 +103,11 @@ func DirExists(path string) (bool, error) {
func replaceInvalidChars(str string) string { func replaceInvalidChars(str string) string {
str = strings.Replace(str, " ", "_", -1) str = strings.Replace(str, " ", "_", -1)
str = strings.Replace(str, "-", "_", -1) str = strings.Replace(str, "-", "_", -1)
str = strings.Replace(str, ".", "_", -1)
return str return str
} }
// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq
func FormatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}
// IsNil check if v is nil // IsNil check if v is nil
func IsNil(v interface{}) bool { func IsNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
@ -174,3 +140,27 @@ func MustBeInitializedPtr(val interface{}, errorStr string) {
panic(errorStr) panic(errorStr)
} }
} }
// PanicOnError panics if err is not nil
func PanicOnError(err error) {
if err != nil {
panic(err)
}
}
// ErrorCatch is used in defer to recover from panics and to set err
func ErrorCatch(err *error) {
recovered := recover()
if recovered == nil {
return
}
recoveredErr, isError := recovered.(error)
if isError {
*err = recoveredErr
} else {
*err = fmt.Errorf("%v", recovered)
}
}

View file

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"fmt"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
) )
@ -23,3 +24,27 @@ func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier("My Table"), "MyTable") assert.Equal(t, ToGoIdentifier("My Table"), "MyTable")
assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
} }
func TestErrorCatchErr(t *testing.T) {
var err error
func() {
defer ErrorCatch(&err)
panic(fmt.Errorf("newError"))
}()
assert.Error(t, err, "newError")
}
func TestErrorCatchNonErr(t *testing.T) {
var err error
func() {
defer ErrorCatch(&err)
panic(11)
}()
assert.Error(t, err, "11")
}

View file

@ -10,13 +10,13 @@ var Dialect = newDialect()
func newDialect() jet.Dialect { func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{} operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysql_REGEXP_LIKE_operator operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysql_NOT_REGEXP_LIKE_operator operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator
operatorSerializeOverrides["IS DISTINCT FROM"] = mysql_IS_DISTINCT_FROM operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM
operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysql_IS_NOT_DISTINCT_FROM operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM
operatorSerializeOverrides["/"] = mysql_DIVISION operatorSerializeOverrides["/"] = mysqlDivision
operatorSerializeOverrides["#"] = mysql_BIT_XOR operatorSerializeOverrides["#"] = mysqlBitXor
operatorSerializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator operatorSerializeOverrides[jet.StringConcatOperator] = mysqlCONCAToperator
mySQLDialectParams := jet.DialectParams{ mySQLDialectParams := jet.DialectParams{
Name: "MySQL", Name: "MySQL",
@ -32,7 +32,7 @@ func newDialect() jet.Dialect {
return jet.NewDialect(mySQLDialectParams) return jet.NewDialect(mySQLDialectParams)
} }
func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc { func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator XOR") panic("jet: invalid number of expressions for operator XOR")
@ -49,7 +49,7 @@ func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc { func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator CONCAT") panic("jet: invalid number of expressions for operator CONCAT")
@ -66,7 +66,7 @@ func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc { func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator DIV") panic("jet: invalid number of expressions for operator DIV")
@ -90,7 +90,7 @@ func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc { func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -102,15 +102,15 @@ func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc
} }
} }
func mysql_IS_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc { func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
out.WriteString("NOT(") out.WriteString("NOT(")
mysql_IS_NOT_DISTINCT_FROM(expressions...)(statement, out, options...) mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...)
out.WriteString(")") out.WriteString(")")
} }
} }
func mysql_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -136,7 +136,7 @@ func mysql_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc
} }
} }
func mysql_NOT_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")

View file

@ -85,6 +85,47 @@ var SUMi = jet.SUMi
// SUMf is aggregate function. Returns sum of float expression. // SUMf is aggregate function. Returns sum of float expression.
var SUMf = jet.SUMf var SUMf = jet.SUMf
// -------------------- Window functions -----------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
var ROW_NUMBER = jet.ROW_NUMBER
// RANK of the current row with gaps; same as row_number of its first peer
var RANK = jet.RANK
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
var DENSE_RANK = jet.DENSE_RANK
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
var PERCENT_RANK = jet.PERCENT_RANK
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
var CUME_DIST = jet.CUME_DIST
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
var NTILE = jet.NTILE
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LAG = jet.LAG
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LEAD = jet.LEAD
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
var FIRST_VALUE = jet.FIRST_VALUE
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
var LAST_VALUE = jet.LAST_VALUE
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------// //--------------------- String functions ------------------//
// BIT_LENGTH returns number of bits in string expression // BIT_LENGTH returns number of bits in string expression
@ -181,7 +222,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampExpression {
// NOW returns current datetime // NOW returns current datetime
func NOW(fsp ...int) DateTimeExpression { func NOW(fsp ...int) DateTimeExpression {
if len(fsp) > 0 { if len(fsp) > 0 {
return jet.NewTimestampFunc("NOW", jet.ConstLiteral(int64(fsp[0]))) return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0])))
} }
return jet.NewTimestampFunc("NOW") return jet.NewTimestampFunc("NOW")
} }

View file

@ -1,6 +1,8 @@
package mysql package mysql
import "github.com/go-jet/jet/internal/jet" import (
"github.com/go-jet/jet/internal/jet"
)
// RowLock is interface for SELECT statement row lock types // RowLock is interface for SELECT statement row lock types
type RowLock = jet.RowLock type RowLock = jet.RowLock
@ -11,6 +13,27 @@ var (
SHARE = jet.NewRowLock("SHARE") SHARE = jet.NewRowLock("SHARE")
) )
// Window function clauses
var (
PARTITION_BY = jet.PARTITION_BY
ORDER_BY = jet.ORDER_BY
UNBOUNDED = jet.UNBOUNDED
CURRENT_ROW = jet.CURRENT_ROW
)
// PRECEDING window frame clause
func PRECEDING(offset interface{}) jet.FrameExtent {
return jet.PRECEDING(toJetFrameOffset(offset))
}
// FOLLOWING window frame clause
func FOLLOWING(offset interface{}) jet.FrameExtent {
return jet.FOLLOWING(toJetFrameOffset(offset))
}
// Window is used to specify window reference from WINDOW clause
var Window = jet.WindowName
// SelectStatement is interface for MySQL SELECT statement // SelectStatement is interface for MySQL SELECT statement
type SelectStatement interface { type SelectStatement interface {
Statement Statement
@ -22,6 +45,7 @@ type SelectStatement interface {
WHERE(expression BoolExpression) SelectStatement WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement OFFSET(offset int64) SelectStatement
@ -42,7 +66,7 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement {
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{} newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock)
newSelect.Select.Projections = toJetProjectionList(projections) newSelect.Select.Projections = toJetProjectionList(projections)
@ -66,6 +90,7 @@ type selectStatementImpl struct {
Where jet.ClauseWhere Where jet.ClauseWhere
GroupBy jet.ClauseGroupBy GroupBy jet.ClauseGroupBy
Having jet.ClauseHaving Having jet.ClauseHaving
Window jet.ClauseWindow
OrderBy jet.ClauseOrderBy OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit Limit jet.ClauseLimit
Offset jet.ClauseOffset Offset jet.ClauseOffset
@ -98,6 +123,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem
return s return s
} }
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement {
s.OrderBy.List = orderByClauses s.OrderBy.List = orderByClauses
return s return s
@ -126,3 +156,31 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
func (s *selectStatementImpl) AsTable(alias string) SelectTable { func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias) return newSelectTable(s, alias)
} }
//-----------------------------------------------------
type windowExpand struct {
selectStatement *selectStatementImpl
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
if len(window) == 0 {
return w.selectStatement
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
}
func toJetFrameOffset(offset interface{}) jet.Serializer {
if offset == UNBOUNDED {
return jet.UNBOUNDED
}
// check for interval expression
//if exp, ok := offset.(Expression); ok {
// return exp
//}
return jet.FixedLiteral(offset)
}

View file

@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectU
} }
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
t := &tableImpl{ t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, columns...), SerializerTable: jet.NewTable(schemaName, name, column, columns...),
} }
t.readableTableInterfaceImpl.parent = t t.readableTableInterfaceImpl.parent = t

View file

@ -11,8 +11,8 @@ var Dialect = newDialect()
func newDialect() jet.Dialect { func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{} operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgres_REGEXP_LIKE_operator operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgresREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgres_NOT_REGEXP_LIKE_operator operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgresNOTREGEXPLIKEoperator
operatorSerializeOverrides["CAST"] = postgresCAST operatorSerializeOverrides["CAST"] = postgresCAST
dialectParams := jet.DialectParams{ dialectParams := jet.DialectParams{
@ -54,7 +54,7 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func postgres_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -80,7 +80,7 @@ func postgres_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeF
} }
} }
func postgres_NOT_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")

View file

@ -87,6 +87,47 @@ var SUMf = jet.SUMf
// SUMi is aggregate function. Returns sum of expression across all integer expression. // SUMi is aggregate function. Returns sum of expression across all integer expression.
var SUMi = jet.SUMi var SUMi = jet.SUMi
// -------------------- Window functions -----------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
var ROW_NUMBER = jet.ROW_NUMBER
// RANK of the current row with gaps; same as row_number of its first peer
var RANK = jet.RANK
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
var DENSE_RANK = jet.DENSE_RANK
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
var PERCENT_RANK = jet.PERCENT_RANK
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
var CUME_DIST = jet.CUME_DIST
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
var NTILE = jet.NTILE
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LAG = jet.LAG
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LEAD = jet.LEAD
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
var FIRST_VALUE = jet.FIRST_VALUE
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
var LAST_VALUE = jet.LAST_VALUE
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------// //--------------------- String functions ------------------//
// BIT_LENGTH returns number of bits in string expression // BIT_LENGTH returns number of bits in string expression

View file

@ -1,6 +1,9 @@
package postgres package postgres
import "github.com/go-jet/jet/internal/jet" import (
"github.com/go-jet/jet/internal/jet"
"math"
)
// RowLock is interface for SELECT statement row lock types // RowLock is interface for SELECT statement row lock types
type RowLock = jet.RowLock type RowLock = jet.RowLock
@ -13,6 +16,27 @@ var (
KEY_SHARE = jet.NewRowLock("KEY SHARE") KEY_SHARE = jet.NewRowLock("KEY SHARE")
) )
// Window function clauses
var (
PARTITION_BY = jet.PARTITION_BY
ORDER_BY = jet.ORDER_BY
UNBOUNDED = int64(math.MaxInt64)
CURRENT_ROW = jet.CURRENT_ROW
)
// PRECEDING window frame clause
func PRECEDING(offset int64) jet.FrameExtent {
return jet.PRECEDING(toJetFrameOffset(offset))
}
// FOLLOWING window frame clause
func FOLLOWING(offset int64) jet.FrameExtent {
return jet.FOLLOWING(toJetFrameOffset(offset))
}
// Window definition reference
var Window = jet.WindowName
// SelectStatement is interface for PostgreSQL SELECT statement // SelectStatement is interface for PostgreSQL SELECT statement
type SelectStatement interface { type SelectStatement interface {
Statement Statement
@ -24,6 +48,7 @@ type SelectStatement interface {
WHERE(expression BoolExpression) SelectStatement WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement OFFSET(offset int64) SelectStatement
@ -47,15 +72,9 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement {
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{} newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For) &newSelect.Limit, &newSelect.Offset, &newSelect.For)
// statementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
// &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy,
// &newSelect.Limit, &newSelect.Offset, &newSelect.For)
//
//newSelect.expressionStatementImpl.expressionInterfaceImpl.Parent = newSelect
newSelect.Select.Projections = toJetProjectionList(projections) newSelect.Select.Projections = toJetProjectionList(projections)
newSelect.From.Table = table newSelect.From.Table = table
newSelect.Limit.Count = -1 newSelect.Limit.Count = -1
@ -75,6 +94,7 @@ type selectStatementImpl struct {
Where jet.ClauseWhere Where jet.ClauseWhere
GroupBy jet.ClauseGroupBy GroupBy jet.ClauseGroupBy
Having jet.ClauseHaving Having jet.ClauseHaving
Window jet.ClauseWindow
OrderBy jet.ClauseOrderBy OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit Limit jet.ClauseLimit
Offset jet.ClauseOffset Offset jet.ClauseOffset
@ -106,6 +126,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem
return s return s
} }
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement {
s.OrderBy.List = orderByClauses s.OrderBy.List = orderByClauses
return s return s
@ -129,3 +154,25 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement {
func (s *selectStatementImpl) AsTable(alias string) SelectTable { func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias) return newSelectTable(s, alias)
} }
//-----------------------------------------------------
type windowExpand struct {
selectStatement *selectStatementImpl
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
if len(window) == 0 {
return w.selectStatement
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
}
func toJetFrameOffset(offset int64) jet.Serializer {
if offset == UNBOUNDED {
return jet.UNBOUNDED
}
return jet.FixedLiteral(offset)
}

View file

@ -109,10 +109,10 @@ type tableImpl struct {
} }
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
t := &tableImpl{ t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, columns...), SerializerTable: jet.NewTable(schemaName, name, column, columns...),
} }
t.readableTableInterfaceImpl.parent = t t.readableTableInterfaceImpl.parent = t

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -60,7 +61,7 @@ func initMySQLDB() {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err := cmd.Run() err := cmd.Run()
panicOnError(err) utils.PanicOnError(err)
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
Host: dbconfig.MySqLHost, Host: dbconfig.MySqLHost,
@ -70,7 +71,7 @@ func initMySQLDB() {
DBName: dbName, DBName: dbName,
}) })
panicOnError(err) utils.PanicOnError(err)
} }
} }
@ -104,22 +105,16 @@ func initPostgresDB() {
SchemaName: schemaName, SchemaName: schemaName,
SslMode: "disable", SslMode: "disable",
}) })
panicOnError(err) utils.PanicOnError(err)
} }
} }
func execFile(db *sql.DB, sqlFilePath string) { func execFile(db *sql.DB, sqlFilePath string) {
testSampleSql, err := ioutil.ReadFile(sqlFilePath) testSampleSql, err := ioutil.ReadFile(sqlFilePath)
panicOnError(err) utils.PanicOnError(err)
_, err = db.Exec(string(testSampleSql)) _, err = db.Exec(string(testSampleSql))
panicOnError(err) utils.PanicOnError(err)
}
func panicOnError(err error) {
if err != nil {
panic(err)
}
} }
func printOnError(err error) { func printOnError(err error) {

View file

@ -5,6 +5,7 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common" "github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid" "github.com/google/uuid"
"time" "time"
@ -36,6 +37,23 @@ func TestAllTypes(t *testing.T) {
testutils.AssertJSON(t, dest, allTypesJson) testutils.AssertJSON(t, dest, allTypesJson)
} }
func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes
dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return
}
testutils.AssertJSON(t, dest, allTypesJson)
}
func TestUUID(t *testing.T) { func TestUUID(t *testing.T) {
query := AllTypes. query := AllTypes.

View file

@ -1,8 +1,8 @@
package mysql package mysql
import ( import (
"bytes"
"github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"gotest.tools/assert" "gotest.tools/assert"
"io/ioutil" "io/ioutil"
@ -15,10 +15,9 @@ const genTestDirRoot = "./.gentestdata3"
const genTestDir3 = "./.gentestdata3/mysql" const genTestDir3 = "./.gentestdata3/mysql"
func TestGenerator(t *testing.T) { func TestGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir3)
assert.NilError(t, err)
err = mysql.Generate(genTestDir3, mysql.DBConnection{ for i := 0; i < 3; i++ {
err := mysql.Generate(genTestDir3, mysql.DBConnection{
Host: dbconfig.MySqLHost, Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort, Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser, User: dbconfig.MySQLUser,
@ -29,8 +28,9 @@ func TestGenerator(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
}
err = os.RemoveAll(genTestDirRoot) err := os.RemoveAll(genTestDirRoot)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -63,53 +63,40 @@ func assertGeneratedFiles(t *testing.T) {
tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table")
assert.NilError(t, err) assert.NilError(t, err)
assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go") "payment.go", "rental.go", "staff.go", "store.go")
assertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view")
assert.NilError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile)
// Enums SQL Builder files // Enums SQL Builder files
enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum")
assert.NilError(t, err) assert.NilError(t, err)
assertFileNameEqual(t, enumFiles, "film_rating.go") testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go")
assertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files // Model files
modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model")
assert.NilError(t, err) assert.NilError(t, err)
assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "film_rating.go") "payment.go", "rental.go", "staff.go", "store.go",
"film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go",
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
assertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile)
}
func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
} }
var mpaaRatingEnumFile = ` var mpaaRatingEnumFile = `
@ -200,3 +187,57 @@ type Actor struct {
LastUpdate time.Time LastUpdate time.Time
} }
` `
var actorInfoSQLBuilerFile = `
package view
import (
"github.com/go-jet/jet/mysql"
)
var ActorInfo = newActorInfoTable()
type ActorInfoTable struct {
mysql.Table
//Columns
ActorID mysql.ColumnInteger
FirstName mysql.ColumnString
LastName mysql.ColumnString
FilmInfo mysql.ColumnString
AllColumns mysql.IColumnList
MutableColumns mysql.IColumnList
}
// creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorInfoTable() *ActorInfoTable {
var (
ActorIDColumn = mysql.IntegerColumn("actor_id")
FirstNameColumn = mysql.StringColumn("first_name")
LastNameColumn = mysql.StringColumn("last_name")
FilmInfoColumn = mysql.StringColumn("film_info")
)
return &ActorInfoTable{
Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
FilmInfo: FilmInfoColumn,
AllColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
MutableColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
}
}
`

View file

@ -1,11 +1,13 @@
package mysql package mysql
import ( import (
"fmt"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
@ -15,7 +17,7 @@ func TestSelect_ScanToStruct(t *testing.T) {
query := Actor. query := Actor.
SELECT(Actor.AllColumns). SELECT(Actor.AllColumns).
DISTINCT(). DISTINCT().
WHERE(Actor.ActorID.EQ(Int(1))) WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT DISTINCT actor.actor_id AS "actor.actor_id", SELECT DISTINCT actor.actor_id AS "actor.actor_id",
@ -24,20 +26,20 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.last_update AS "actor.last_update" actor.last_update AS "actor.last_update"
FROM dvds.actor FROM dvds.actor
WHERE actor.actor_id = ?; WHERE actor.actor_id = ?;
`, int64(1)) `, int64(2))
actor := model.Actor{} actor := model.Actor{}
err := query.Query(db, &actor) err := query.Query(db, &actor)
assert.NilError(t, err) assert.NilError(t, err)
assert.DeepEqual(t, actor, actor1) assert.DeepEqual(t, actor, actor2)
} }
var actor1 = model.Actor{ var actor2 = model.Actor{
ActorID: 1, ActorID: 2,
FirstName: "PENELOPE", FirstName: "NICK",
LastName: "GUINESS", LastName: "WAHLBERG",
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2),
} }
@ -61,7 +63,7 @@ ORDER BY actor.actor_id;
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 200) assert.Equal(t, len(dest), 200)
assert.DeepEqual(t, dest[0], actor1) assert.DeepEqual(t, dest[1], actor2)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
@ -527,3 +529,172 @@ LOCK IN SHARE MODE;
err := query.Query(db, &struct{}{}) err := query.Query(db, &struct{}{})
assert.NilError(t, err) assert.NilError(t, err)
} }
func TestWindowFunction(t *testing.T) {
if sourceIsMariaDB() {
return
}
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (PARTITION BY payment.customer_id),
MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC),
MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING),
MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
ROW_NUMBER() OVER (ORDER BY payment.payment_date),
RANK() OVER (ORDER BY payment.payment_date),
DENSE_RANK() OVER (ORDER BY payment.payment_date),
CUME_DIST() OVER (ORDER BY payment.payment_date),
NTILE(11) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date),
FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date)
FROM dvds.payment
WHERE payment.payment_id < ?
GROUP BY payment.amount, payment.customer_id, payment.payment_date;
`
query := Payment.
SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)),
MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())),
MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))),
MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)),
RANK().OVER(ORDER_BY(Payment.PaymentDate)),
DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)),
CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)),
NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)),
).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate).
WHERE(Payment.PaymentID.LT(Int(10)))
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestWindowClause(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (w1),
AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
FROM dvds.payment
WHERE payment.payment_id < ?
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id;
`
query := Payment.SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER(
Window("w2").
ORDER_BY(Payment.CustomerID).
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
).
WHERE(Payment.PaymentID.LT(Int(10))).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")).
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(Payment.CustomerID)
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestSimpleView(t *testing.T) {
query := SELECT(
view.ActorInfo.AllColumns,
).
FROM(view.ActorInfo).
ORDER_BY(view.ActorInfo.ActorID).
LIMIT(10)
type ActorInfo struct {
ActorID int
FirstName string
LastName string
FilmInfo string
}
var dest []ActorInfo
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 10)
testutils.AssertJSON(t, dest[1:2], `
[
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"FilmInfo": "Action: BULL SHAWSHANK; Animation: FIGHT JAWBREAKER; Children: JERSEY SASSY; Classics: DRACULA CRYSTAL, GILBERT PELICAN; Comedy: MALLRATS UNITED, RUSHMORE MERMAID; Documentary: ADAPTATION HOLES; Drama: WARDROBE PHANTOM; Family: APACHE DIVINE, CHISUM BEHAVIOR, INDIAN LOVE, MAGUIRE APACHE; Foreign: BABY HALL, HAPPINESS UNITED; Games: ROOF CHAMPION; Music: LUCKY FLYING; New: DESTINY SATURDAY, FLASH WARS, JEKYLL FROGMEN, MASK PEACH; Sci-Fi: CHAINSAW UPTOWN, GOODFELLAS SALUTE; Travel: LIAISONS SWEET, SMILE EARRING"
}
]
`)
}
func TestJoinViewWithTable(t *testing.T) {
query := SELECT(
view.CustomerList.AllColumns,
Rental.AllColumns,
).
FROM(view.CustomerList.
INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)),
).
ORDER_BY(view.CustomerList.ID).
WHERE(view.CustomerList.ID.LT_EQ(Int(2)))
var dest []struct {
model.CustomerList `sql:"primary_key=ID"`
Rentals []model.Rental
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27)
}

View file

@ -6,6 +6,7 @@ import (
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common" "github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid" "github.com/google/uuid"
"gotest.tools/assert" "gotest.tools/assert"
@ -23,6 +24,19 @@ func TestAllTypesSelect(t *testing.T) {
assert.DeepEqual(t, dest[1], allTypesRow1) assert.DeepEqual(t, dest[1], allTypesRow1)
} }
func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes
dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest[0], AllTypesView(allTypesRow0))
assert.DeepEqual(t, dest[1], AllTypesView(allTypesRow1))
}
func TestAllTypesInsertModel(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) {
query := AllTypes.INSERT(AllTypes.AllColumns). query := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(allTypesRow0). MODEL(allTypesRow0).
@ -31,8 +45,8 @@ func TestAllTypesInsertModel(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0], allTypesRow0) assert.DeepEqual(t, dest[0], allTypesRow0)
assert.DeepEqual(t, dest[1], allTypesRow1) assert.DeepEqual(t, dest[1], allTypesRow1)

View file

@ -1,8 +1,8 @@
package postgres package postgres
import ( import (
"bytes"
"github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"gotest.tools/assert" "gotest.tools/assert"
"io/ioutil" "io/ioutil"
@ -71,10 +71,8 @@ func TestCmdGenerator(t *testing.T) {
func TestGenerator(t *testing.T) { func TestGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir2) for i := 0; i < 3; i++ {
assert.NilError(t, err) err := postgres.Generate(genTestDir2, postgres.DBConnection{
err = postgres.Generate(genTestDir2, postgres.DBConnection{
Host: dbconfig.Host, Host: dbconfig.Host,
Port: dbconfig.Port, Port: dbconfig.Port,
User: dbconfig.User, User: dbconfig.User,
@ -89,8 +87,9 @@ func TestGenerator(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
}
err = os.RemoveAll(genTestDir2) err := os.RemoveAll(genTestDir2)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -99,53 +98,39 @@ func assertGeneratedFiles(t *testing.T) {
tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table")
assert.NilError(t, err) assert.NilError(t, err)
assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go") "payment.go", "rental.go", "staff.go", "store.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view")
assert.NilError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile)
// Enums SQL Builder files // Enums SQL Builder files
enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum")
assert.NilError(t, err) assert.NilError(t, err)
assertFileNameEqual(t, enumFiles, "mpaa_rating.go") testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files // Model files
modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model")
assert.NilError(t, err) assert.NilError(t, err)
assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go") "payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go",
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile) testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile)
}
func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
} }
var mpaaRatingEnumFile = ` var mpaaRatingEnumFile = `
@ -236,3 +221,57 @@ type Actor struct {
LastUpdate time.Time LastUpdate time.Time
} }
` `
var actorInfoSQLBuilderFile = `
package view
import (
"github.com/go-jet/jet/postgres"
)
var ActorInfo = newActorInfoTable()
type ActorInfoTable struct {
postgres.Table
//Columns
ActorID postgres.ColumnInteger
FirstName postgres.ColumnString
LastName postgres.ColumnString
FilmInfo postgres.ColumnString
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorInfoTable() *ActorInfoTable {
var (
ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name")
FilmInfoColumn = postgres.StringColumn("film_info")
)
return &ActorInfoTable{
Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
FilmInfo: FilmInfoColumn,
AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
MutableColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
}
}
`

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
"time" "time"
@ -19,15 +20,15 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.last_name AS "actor.last_name", actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update" actor.last_update AS "actor.last_update"
FROM dvds.actor FROM dvds.actor
WHERE actor.actor_id = 1; WHERE actor.actor_id = 2;
` `
query := Actor. query := Actor.
SELECT(Actor.AllColumns). SELECT(Actor.AllColumns).
DISTINCT(). DISTINCT().
WHERE(Actor.ActorID.EQ(Int(1))) WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2))
actor := model.Actor{} actor := model.Actor{}
err := query.Query(db, &actor) err := query.Query(db, &actor)
@ -35,9 +36,9 @@ WHERE actor.actor_id = 1;
assert.NilError(t, err) assert.NilError(t, err)
expectedActor := model.Actor{ expectedActor := model.Actor{
ActorID: 1, ActorID: 2,
FirstName: "Penelope", FirstName: "Nick",
LastName: "Guiness", LastName: "Wahlberg",
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2),
} }
@ -1615,3 +1616,169 @@ SELECT true,
err := query.Query(db, &struct{}{}) err := query.Query(db, &struct{}{})
assert.NilError(t, err) assert.NilError(t, err)
} }
func TestWindowFunction(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (PARTITION BY payment.customer_id),
MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC),
MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING),
MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
ROW_NUMBER() OVER (ORDER BY payment.payment_date),
RANK() OVER (ORDER BY payment.payment_date),
DENSE_RANK() OVER (ORDER BY payment.payment_date),
CUME_DIST() OVER (ORDER BY payment.payment_date),
NTILE(11) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, $1) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, $2) OVER (ORDER BY payment.payment_date),
FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date)
FROM dvds.payment
WHERE payment.payment_id < $3
GROUP BY payment.amount, payment.customer_id, payment.payment_date;
`
query := Payment.
SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)),
MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())),
MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))),
MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)),
RANK().OVER(ORDER_BY(Payment.PaymentDate)),
DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)),
CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)),
NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)),
).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate).
WHERE(Payment.PaymentID.LT(Int(10)))
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestWindowClause(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (w1),
AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
FROM dvds.payment
WHERE payment.payment_id < $1
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id;
`
query := Payment.SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER(
Window("w2").
ORDER_BY(Payment.CustomerID).
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
).
WHERE(Payment.PaymentID.LT(Int(10))).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")).
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(Payment.CustomerID)
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
}
func TestSimpleView(t *testing.T) {
query := SELECT(
view.ActorInfo.AllColumns,
).
FROM(view.ActorInfo).
ORDER_BY(view.ActorInfo.ActorID).
LIMIT(10)
type ActorInfo struct {
ActorID int
FirstName string
LastName string
FilmInfo string
}
var dest []ActorInfo
err := query.Query(db, &dest)
assert.NilError(t, err)
testutils.AssertJSON(t, dest[1:2], `
[
{
"ActorID": 2,
"FirstName": "Nick",
"LastName": "Wahlberg",
"FilmInfo": "Action: Bull Shawshank, Animation: Fight Jawbreaker, Children: Jersey Sassy, Classics: Dracula Crystal, Gilbert Pelican, Comedy: Mallrats United, Rushmore Mermaid, Documentary: Adaptation Holes, Drama: Wardrobe Phantom, Family: Apache Divine, Chisum Behavior, Indian Love, Maguire Apache, Foreign: Baby Hall, Happiness United, Games: Roof Champion, Music: Lucky Flying, New: Destiny Saturday, Flash Wars, Jekyll Frogmen, Mask Peach, Sci-Fi: Chainsaw Uptown, Goodfellas Salute, Travel: Liaisons Sweet, Smile Earring"
}
]
`)
}
func TestJoinViewWithTable(t *testing.T) {
query := SELECT(
view.CustomerList.AllColumns,
Rental.AllColumns,
).
FROM(view.CustomerList.
INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)),
).
ORDER_BY(view.CustomerList.ID).
WHERE(view.CustomerList.ID.LT_EQ(Int(2)))
var dest []struct {
model.CustomerList `sql:"primary_key=ID"`
Rentals []model.Rental
}
fmt.Println(query.DebugSql())
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27)
}

@ -1 +1 @@
Subproject commit 7f3f3cc26ce34324f3699d6b422376671b827490 Subproject commit 1f6bd8bb86458019fa43b1e2cd7ae9488a7ac9a4