Merge pull request #99 from go-jet/develop

Release 2.6.0
This commit is contained in:
go-jet 2021-10-25 16:33:21 +02:00 committed by GitHub
commit d335b6cdad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
124 changed files with 9352 additions and 1787 deletions

View file

@ -69,7 +69,7 @@ jobs:
- run: - run:
name: Install MySQL CLI; name: Install MySQL CLI;
command: | command: |
sudo apt-get update && sudo apt-get install default-mysql-client sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client
- run: - run:
name: Create MySQL user and databases name: Create MySQL user and databases
@ -88,7 +88,8 @@ jobs:
- run: mkdir -p $TEST_RESULTS - run: mkdir -p $TEST_RESULTS
- run: MY_SQL_SOURCE=MySQL go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml # this will run all tests and exclude test files from code coverage report
- run: MY_SQL_SOURCE=MySQL go test -v ./... -covermode=atomic -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/sqlite/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run: - run:
name: Upload code coverage name: Upload code coverage
@ -138,7 +139,7 @@ jobs:
- run: - run:
name: Install MySQL CLI; name: Install MySQL CLI;
command: | command: |
sudo apt-get update && sudo apt-get install default-mysql-client sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client
- run: - run:
name: Init MariaDB database name: Init MariaDB database

1
.gitignore vendored
View file

@ -19,3 +19,4 @@
gen gen
.gentestdata .gentestdata
.tests/testdata/ .tests/testdata/
.gen

108
README.md
View file

@ -9,10 +9,10 @@
Jet is a complete solution for efficient and high performance database access, consisting of type-safe SQL builder Jet is a complete solution for efficient and high performance database access, consisting of type-safe SQL builder
with code generation and automatic query result data mapping. with code generation and automatic query result data mapping.
Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases. Jet currently supports `PostgreSQL`, `MySQL`, `MariaDB` and `SQLite`. Future releases will add support for additional databases.
![jet](https://github.com/go-jet/jet/wiki/image/jet.png) ![jet](https://github.com/go-jet/jet/wiki/image/jet.png)
Jet is the easiest and the fastest way to write complex type-safe SQL queries as a Go code and map database query result Jet is the easiest, and the fastest way to write complex type-safe SQL queries as a Go code and map database query result
into complex object composition. __It is not an ORM.__ into complex object composition. __It is not an ORM.__
## Motivation ## Motivation
@ -24,7 +24,7 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
- [Prerequisites](#prerequisites) - [Prerequisites](#prerequisites)
- [Installation](#installation) - [Installation](#installation)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Generate sql builder and model files](#generate-sql-builder-and-model-files) - [Generate sql builder and model types](#generate-sql-builder-and-model-types)
- [Lets write some SQL queries in Go](#lets-write-some-sql-queries-in-go) - [Lets write some SQL queries in Go](#lets-write-some-sql-queries-in-go)
- [Execute query and store result](#execute-query-and-store-result) - [Execute query and store result](#execute-query-and-store-result)
- [Benefits](#benefits) - [Benefits](#benefits)
@ -33,24 +33,17 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
- [License](#license) - [License](#license)
## Features ## Features
1) Auto-generated type-safe SQL Builder 1) Auto-generated type-safe SQL Builder. Statements supported:
- PostgreSQL: * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, LOCK_IN_SHARE_MODE, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)`
* [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT/ON_DUPLICATE_KEY_UPDATE, RETURNING)`,
* [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT, RETURNING)`,
* [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE, RETURNING)`, * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE, RETURNING)`,
* [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, RETURNING)`, * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT, RETURNING)`,
* [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)` * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)`, `(READ, WRITE)`
* [WITH](https://github.com/go-jet/jet/wiki/WITH)
- MySQL and MariaDB:
* [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)`
* [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, ON_DUPLICATE_KEY_UPDATE, query)`,
* [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE)`,
* [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT)`,
* [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(READ, WRITE)`
* [WITH](https://github.com/go-jet/jet/wiki/WITH) * [WITH](https://github.com/go-jet/jet/wiki/WITH)
2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store
result of database queries. Can be combined to create desired query result destination. result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure. 3) Query execution with result mapping to arbitrary destination.
## Getting Started ## Getting Started
@ -67,7 +60,7 @@ Use the command bellow to add jet as a dependency into `go.mod` project:
$ go get -u github.com/go-jet/jet/v2 $ go get -u github.com/go-jet/jet/v2
``` ```
Jet generator can be install in the following ways: Jet generator can be installed in the following ways:
1) Install jet generator to GOPATH/bin folder: 1) Install jet generator to GOPATH/bin folder:
```sh ```sh
@ -75,7 +68,7 @@ cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet
``` ```
*Make sure GOPATH/bin folder is added to the PATH environment variable.* *Make sure GOPATH/bin folder is added to the PATH environment variable.*
2) Install jet generator to specific folder: 2) Install jet generator into specific folder:
```sh ```sh
git clone https://github.com/go-jet/jet.git git clone https://github.com/go-jet/jet.git
@ -91,19 +84,20 @@ go install github.com/go-jet/jet/v2/cmd/jet@latest
which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.* which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.*
### Quick Start ### Quick Start
For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in [./tests/testdata/init/postgres/dvds.sql](./tests/testdata/init/postgres/dvds.sql). For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in
[./tests/testdata/init/postgres/dvds.sql](https://github.com/go-jet/jet-test-data/blob/master/init/postgres/dvds.sql).
Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png). Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png).
#### Generate SQL Builder and Model files #### Generate SQL Builder and Model types
To generate jet SQL Builder and Data Model files from postgres database, we need to call `jet` generator with postgres To generate jet SQL Builder and Data Model types from postgres database, we need to call `jet` generator with postgres
connection parameters and root destination folder path for generated files.\ connection parameters and root destination folder path for generated files.
Assuming we are running local postgres database, with user `jetuser`, user password `jetpass`, database `jetdb` and Assuming we are running local postgres database, with user `user`, user password `pass`, database `jetdb` and
schema `dvds` we will use this command: schema `dvds` we will use this command:
```sh ```sh
jet -source=PostgreSQL -host=localhost -port=5432 -user=jetuser -password=jetpass -dbname=jetdb -schema=dvds -path=./.gen jet -dsn=postgresql://user:pass@localhost:5432/jetdb -schema=dvds -path=./.gen
``` ```
```sh ```sh
Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable Connecting to postgres database: postgresql://user:pass@localhost:5432/jetdb
Retrieving schema information... Retrieving schema information...
FOUND 15 table(s), 7 view(s), 1 enum(s) FOUND 15 table(s), 7 view(s), 1 enum(s)
Cleaning up destination directory... Cleaning up destination directory...
@ -115,14 +109,19 @@ Generating view model files...
Generating enum model files... Generating enum model files...
Done Done
``` ```
Procedure is similar for MySQL or MariaDB, except source should be replaced with `MySql` or `MariaDB` and schema name should Procedure is similar for MySQL, MariaDB and SQLite. For instance:
be omitted (both databases doesn't have schema support). ```sh
jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./gen
jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./gen # source flag can be omitted if data source appears in dsn
jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./gen
jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./gen # sqlite database assumed for 'file' data sources
```
_*User has to have a permission to read information schema tables._ _*User has to have a permission to read information schema tables._
As command output suggest, Jet will: As command output suggest, Jet will:
- connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema - connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema
- delete everything in schema destination folder - `./gen/jetdb/dvds`, - delete everything in schema destination folder - `./gen/jetdb/dvds`,
- and finally generate SQL Builder and Model files for each schema table, view and enum. - and finally generate SQL Builder and Model types for each schema table, view and enum.
Generated files folder structure will look like this: Generated files folder structure will look like this:
@ -147,14 +146,14 @@ Generated files folder structure will look like this:
| | |-- mpaa_rating.go | | |-- mpaa_rating.go
| | ... | | ...
``` ```
Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types are combined to store
results of the SQL queries. results of the SQL queries.
#### Lets write some SQL queries in Go #### Let's write some SQL queries in Go
First we need to import jet and generated files from previous step: First we need to import postgres SQLBuilder and generated packages from the previous step:
```go ```go
import ( import (
// dot import so go code would resemble as much as native SQL // dot import so go code would resemble as much as native SQL
@ -165,7 +164,7 @@ import (
"github.com/go-jet/jet/v2/examples/quick-start/gen/jetdb/dvds/model" "github.com/go-jet/jet/v2/examples/quick-start/gen/jetdb/dvds/model"
) )
``` ```
Lets say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English' Let's say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English'
and _film category_ is not 'Action'. and _film category_ is not 'Action'.
```java ```java
stmt := SELECT( stmt := SELECT(
@ -189,17 +188,17 @@ stmt := SELECT(
Film.FilmID.ASC(), Film.FilmID.ASC(),
) )
``` ```
_Package(dot) import is used so that statement would resemble as much as possible as native SQL._ _Package(dot) import is used, so the statements would resemble as much as possible as native SQL._
Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns
and can be compared only with integer columns and expressions. and can be compared only with integer columns and expressions.
__How to get parametrized SQL query from statement?__ __How to get a parametrized SQL query from the statement?__
```go ```go
query, args := stmt.Sql() query, args := stmt.Sql()
``` ```
query - parametrized query\ query - parametrized query
args - parameters for the query args - query parameters
<details> <details>
<summary>Click to see `query` and `args`</summary> <summary>Click to see `query` and `args`</summary>
@ -248,7 +247,7 @@ __How to get debug SQL from statement?__
```go ```go
debugSql := stmt.DebugSql() debugSql := stmt.DebugSql()
``` ```
debugSql - query string that can be copy pasted to sql editor and executed. __It's not intended to be used in production!!!__ debugSql - this query string can be copy-pasted to sql editor and executed. __It is not intended to be used in production, only for the purpose of debugging!!!__
<details> <details>
<summary>Click to see debug sql</summary> <summary>Click to see debug sql</summary>
@ -291,14 +290,17 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
#### Execute query and store result #### Execute query and store result
Well formed SQL is just a first half of the job. Lets see how can we make some sense of result set returned executing Well-formed SQL is just a first half of the job. Let's see how can we make some sense of result set returned executing
above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest. above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest.
First we have to create desired structure to store query result. First we have to create desired structure to store query result.
This is done be combining autogenerated model types or it can be done This is done be combining autogenerated model types, or it can be done
manually(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)) for more information). by combining custom model types(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)#custom-model-types) for more information).
Let's say this is our desired structure: It's possible to overwrite default jet generator behavior, and all the aspects of generated model and SQLBuilder types can be
tailor-made([wiki](https://github.com/go-jet/jet/wiki/Generator#generator-customization)).
Let's say this is our desired structure made of autogenerated types:
```go ```go
var dest []struct { var dest []struct {
model.Actor model.Actor
@ -315,7 +317,7 @@ var dest []struct {
`Langauge` field is just a single model struct. `Film` can belong to multiple categories. `Langauge` field is just a single model struct. `Film` can belong to multiple categories.
_*There is no limitation of how big or nested destination can be._ _*There is no limitation of how big or nested destination can be._
Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`. Now lets execute above statement on open database connection (or transaction) db and store result into `dest`.
```go ```go
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
@ -524,7 +526,7 @@ found at project [Wiki](https://github.com/go-jet/jet/wiki) page.
## Benefits ## Benefits
What are the benefits of writing SQL in Go using Jet? What are the benefits of writing SQL in Go using Jet?
The biggest benefit is speed. Speed is improved in 3 major areas: The biggest benefit is speed. Speed is being improved in 3 major areas:
##### Speed of development ##### Speed of development
@ -538,32 +540,34 @@ Jet will always perform better as developers can write complex query and retriev
Thus handler time lost on latency between server and database can be constant. Handler execution will be proportional Thus handler time lost on latency between server and database can be constant. Handler execution will be proportional
only to the query complexity and the number of rows returned from database. only to the query complexity and the number of rows returned from database.
With Jet it is even possible to join the whole database and store the whole structured result in one database call. With Jet, it is even possible to join the whole database and store the whole structured result in one database call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40). This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40).
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found ##### How quickly bugs are found
The most expensive bugs are the one on the production and the least expensive are those found during development. The most expensive bugs are the one discovered on the production, and the least expensive are those found during development.
With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner. With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner.
Lets return to quick start example, and take closer look at a line: Lets return to quick start example, and take closer look at a line:
```go ```go
AND(Film.Length.GT(Int(180))), AND(Film.Length.GT(Int(180))),
``` ```
Lets say someone changes column `length` to `duration` from `film` table. The next go build will fail at that line and Let's say someone changes column `length` to `duration` from `film` table. The next go build will fail at that line, and
the bug will be caught at compile time. the bug will be caught at compile time.
Lets say someone changes the type of `length` column to some non integer type. Build will also fail at the same line Let's say someone changes the type of `length` column to some non integer type. Build will also fail at the same line
because integer columns and expressions can be only compered to other integer columns and expressions. because integer columns and expressions can be only compared to other integer columns and expressions.
Build will also fail if someone removes `length` column from `film` table, because `Film` field will be omitted from SQL Builder and Model types, next time `jet` generator is run. Build will also fail if someone removes `length` column from `film` table. `Film` field will be omitted from SQL Builder and Model types,
next time `jet` generator is run.
Without Jet these bugs will have to be either caught by some test or by manual testing. Without Jet these bugs will have to be either caught by some test or by manual testing.
## Dependencies ## Dependencies
At the moment Jet dependence only of: At the moment Jet dependence only of:
- `github.com/lib/pq` _(Used by jet generator to read information about database schema from `PostgreSQL`)_ - `github.com/lib/pq` _(Used by jet generator to read `PostgreSQL` database information)_
- `github.com/go-sql-driver/mysql` _(Used by jet generator to read information about database from `MySQL` and `MariaDB`)_ - `github.com/go-sql-driver/mysql` _(Used by jet generator to read `MySQL` and `MariaDB` database information)_
- `github.com/mattn/go-sqlite3` _(Used by jet generator to read `SQLite` database information)_
- `github.com/google/uuid` _(Used in data model files and for debug purposes)_ - `github.com/google/uuid` _(Used in data model files and for debug purposes)_
To run the tests, additional dependencies are required: To run the tests, additional dependencies are required:

View file

@ -3,19 +3,21 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
mysqlgen "github.com/go-jet/jet/v2/generator/mysql" sqlitegen "github.com/go-jet/jet/v2/generator/sqlite"
postgresgen "github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/postgres"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"os" "os"
"strings" "strings"
mysqlgen "github.com/go-jet/jet/v2/generator/mysql"
postgresgen "github.com/go-jet/jet/v2/generator/postgres"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
) )
var ( var (
source string source string
dsn string
host string host string
port int port int
user string user string
@ -29,8 +31,9 @@ var (
) )
func init() { func init() {
flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL or MariaDB)") flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL, MariaDB or SQLite)")
flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port") flag.IntVar(&port, "port", 0, "Database port")
flag.StringVar(&user, "user", "", "Database user") flag.StringVar(&user, "user", "", "Database user")
@ -47,11 +50,22 @@ func main() {
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, ` _, _ = fmt.Fprint(os.Stdout, `
Jet generator 2.5.0 Jet generator 2.6.0
Usage: Usage:
-dsn string
Data source name. Unified format for connecting to database.
PostgreSQL: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
Example:
postgresql://user:pass@localhost:5432/dbname
MySQL: https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html
Example:
mysql://jet:jet@tcp(localhost:3306)/dvds
SQLite: https://www.sqlite.org/c3ref/open.html#urifilenameexamples
Example:
file://path/to/database/file
-source string -source string
Database system name (PostgreSQL, MySQL or MariaDB) Database system name (PostgreSQL, MySQL, MariaDB or SQLite)
-host string -host string
Database host path (Example: localhost) Database host path (Example: localhost)
-port int -port int
@ -65,25 +79,48 @@ Usage:
-params string -params string
Additional connection string parameters(optional) Additional connection string parameters(optional)
-schema string -schema string
Database schema name. (default "public") (ignored for MySQL and MariaDB) Database schema name. (default "public") (ignored for MySQL, MariaDB and SQLite)
-sslmode string -sslmode string
Whether or not to use SSL(optional) (default "disable") (ignored for MySQL and MariaDB) Whether or not to use SSL(optional) (default "disable") (ignored for MySQL, MariaDB and SQLite)
-path string -path string
Destination dir for files generated. Destination dir for files generated.
Example commands:
$ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=dvds -path=./gen
$ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds -path=./gen
$ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -schema=dvds -path=./gen
`) `)
} }
flag.Parse() flag.Parse()
if dsn == "" {
// validations for separated connection flags.
if source == "" || host == "" || port == 0 || user == "" || dbName == "" { if source == "" || host == "" || port == 0 || user == "" || dbName == "" {
printErrorAndExit("\nERROR: required flag(s) missing") printErrorAndExit("ERROR: required flag(s) missing")
}
} else {
if source == "" {
// try to get source from schema
source = detectSchema(dsn)
}
// validations when dsn != ""
if source == "" {
printErrorAndExit("ERROR: required -source flag missing.")
}
} }
var err error var err error
switch strings.ToLower(strings.TrimSpace(source)) { switch strings.ToLower(strings.TrimSpace(source)) {
case strings.ToLower(postgres.Dialect.Name()), case "postgresql", "postgres":
strings.ToLower(postgres.Dialect.PackageName()): if dsn != "" {
err = postgresgen.GenerateDSN(dsn, schemaName, destDir)
break
}
genData := postgresgen.DBConnection{ genData := postgresgen.DBConnection{
Host: host, Host: host,
Port: port, Port: port,
@ -98,8 +135,11 @@ Usage:
err = postgresgen.Generate(destDir, genData) err = postgresgen.Generate(destDir, genData)
case strings.ToLower(mysql.Dialect.Name()), "mariadb": case "mysql", "mysqlx", "mariadb":
if dsn != "" {
err = mysqlgen.GenerateDSN(dsn, destDir)
break
}
dbConn := mysqlgen.DBConnection{ dbConn := mysqlgen.DBConnection{
Host: host, Host: host,
Port: port, Port: port,
@ -110,9 +150,13 @@ Usage:
} }
err = mysqlgen.Generate(destDir, dbConn) err = mysqlgen.Generate(destDir, dbConn)
case "sqlite":
if dsn == "" {
printErrorAndExit("ERROR: required -dsn flag missing.")
}
err = sqlitegen.GenerateDSN(dsn, destDir)
default: default:
fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.") printErrorAndExit("ERROR: unknown data source " + source + ". Only postgres, mysql, mariadb and sqlite are supported.")
os.Exit(-4)
} }
if err != nil { if err != nil {
@ -122,7 +166,22 @@ Usage:
} }
func printErrorAndExit(error string) { func printErrorAndExit(error string) {
fmt.Println(error) fmt.Println("\n", error)
flag.Usage() flag.Usage()
os.Exit(-2) os.Exit(-2)
} }
func detectSchema(dsn string) string {
match := strings.SplitN(dsn, "://", 2)
if len(match) < 2 { // not found
return ""
}
protocol := match[0]
if protocol == "file" {
return "sqlite"
}
return match[0]
}

View file

@ -20,10 +20,17 @@ const (
) )
func (e *MpaaRating) Scan(value interface{}) error { func (e *MpaaRating) Scan(value interface{}) error {
if v, ok := value.(string); !ok { var enumValue string
return errors.New("jet: Invalid data for MpaaRating enum") switch val := value.(type) {
} else { case string:
switch string(v) { enumValue = val
case []byte:
enumValue = string(val)
default:
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
}
switch enumValue {
case "G": case "G":
*e = MpaaRating_G *e = MpaaRating_G
case "PG": case "PG":
@ -35,12 +42,11 @@ func (e *MpaaRating) Scan(value interface{}) error {
case "NC-17": case "NC-17":
*e = MpaaRating_Nc17 *e = MpaaRating_Nc17
default: default:
return errors.New("jet: Inavlid data " + string(v) + "for MpaaRating enum") return errors.New("jet: Invalid scan value '" + enumValue + "' for MpaaRating enum")
} }
return nil return nil
} }
}
func (e MpaaRating) String() string { func (e MpaaRating) String() string {
return string(e) return string(e)

View file

@ -1,168 +0,0 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/internal/utils"
"strings"
)
// ColumnMetaData struct
type ColumnMetaData struct {
Name string
IsNullable bool
DataType string
EnumName string
IsUnsigned bool
SqlBuilderColumnType string
GoBaseType string
GoModelType string
}
// NewColumnMetaData create new column meta data that describes one column in SQL database
func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData {
columnMetaData := ColumnMetaData{
Name: name,
IsNullable: isNullable,
DataType: dataType,
EnumName: enumName,
IsUnsigned: isUnsigned,
}
columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType()
columnMetaData.GoBaseType = columnMetaData.getGoBaseType()
columnMetaData.GoModelType = columnMetaData.getGoModelType()
return columnMetaData
}
// getSqlBuilderColumnType returns type of jet sql builder column
func (c ColumnMetaData) getSqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "interval":
return "Interval"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.")
return "String"
}
}
// getGoBaseType returns model type for column info.
func (c ColumnMetaData) getGoBaseType() string {
switch c.DataType {
case "USER-DEFINED", "enum":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "tinyint":
return "int8"
case "smallint",
"year":
return "int16"
case "integer",
"mediumint", "int": //MySQL
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone",
"timestamp", "datetime", "time": // MySQL
return "time.Time"
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnMetaData) getGoModelType() string {
typeStr := c.GoBaseType
if strings.Contains(typeStr, "int") && c.IsUnsigned {
typeStr = "u" + typeStr
}
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
utils.PanicOnError(err)
defer rows.Close()
ret := []ColumnMetaData{}
for rows.Next() {
var name, isNullable, dataType, enumName string
var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
utils.PanicOnError(err)
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
}
err = rows.Err()
utils.PanicOnError(err)
return ret
}

View file

@ -1,15 +0,0 @@
package metadata
import (
"database/sql"
)
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information
type DialectQuerySet interface {
ListOfTablesQuery() string
PrimaryKeysQuery() string
ListOfColumnsQuery() string
ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData
}

View file

@ -1,12 +0,0 @@
package metadata
// EnumMetaData struct
type EnumMetaData struct {
EnumName string
Values []string
}
// Name returns enum name
func (e EnumMetaData) Name() string {
return e.EnumName
}

View file

@ -1,6 +0,0 @@
package metadata
// MetaData interface
type MetaData interface {
Name() string
}

View file

@ -1,61 +0,0 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/internal/utils"
)
// SchemaMetaData struct
type SchemaMetaData struct {
TablesMetaData []MetaData
ViewsMetaData []MetaData
EnumsMetaData []MetaData
}
// IsEmpty returns true if schema info does not contain any table, views or enums metadata
func (s SchemaMetaData) IsEmpty() bool {
return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0
}
const (
baseTable = "BASE TABLE"
view = "VIEW"
)
// GetSchemaMetaData returns schema information from db connection.
func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) {
schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable)
schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view)
schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName)
fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),",
len(schemaInfo.EnumsMetaData), "enum(s)")
return
}
func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType)
utils.PanicOnError(err)
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
utils.PanicOnError(err)
tableInfo := GetTableMetaData(db, querySet, schemaName, tableName)
ret = append(ret, tableInfo)
}
err = rows.Err()
utils.PanicOnError(err)
return ret
}

View file

@ -1,103 +0,0 @@
package metadata
import (
"database/sql"
"github.com/go-jet/jet/v2/internal/utils"
"strings"
)
// TableMetaData metadata struct
type TableMetaData struct {
SchemaName string
name string
PrimaryKeys map[string]bool
Columns []ColumnMetaData
}
// Name returns table info name
func (t TableMetaData) Name() string {
return t.name
}
// IsPrimaryKey returns if column is a part of primary key
func (t TableMetaData) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column]
}
// MutableColumns returns list of mutable columns for table
func (t TableMetaData) MutableColumns() []ColumnMetaData {
ret := []ColumnMetaData{}
for _, column := range t.Columns {
if t.IsPrimaryKey(column.Name) {
continue
}
ret = append(ret, column)
}
return ret
}
// GetImports returns model imports for table.
func (t TableMetaData) GetImports() []string {
imports := map[string]string{}
for _, column := range t.Columns {
columnType := column.GoBaseType
switch columnType {
case "time.Time":
imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
}
}
ret := []string{}
for _, packageImport := range imports {
ret = append(ret, packageImport)
}
return ret
}
// GoStructName returns go struct name for sql builder
func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table"
}
// GoStructImplName returns go struct impl name for sql builder
func (t TableMetaData) GoStructImplName() string {
name := utils.ToGoIdentifier(t.name) + "Table"
return string(strings.ToLower(name)[0]) + name[1:]
}
// GetTableMetaData returns table info metadata
func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) {
tableInfo.SchemaName = schemaName
tableInfo.name = tableName
tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName)
tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName)
return
}
func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool {
rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
utils.PanicOnError(err)
primaryKeyMap := map[string]bool{}
for rows.Next() {
primaryKey := ""
err := rows.Scan(&primaryKey)
utils.PanicOnError(err)
primaryKeyMap[primaryKey] = true
}
return primaryKeyMap
}

View file

@ -1,107 +0,0 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/v2/generator/internal/metadata"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils"
"path/filepath"
"text/template"
)
// GenerateFiles generates Go files from tables and enums metadata
func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) {
if schemaInfo.IsEmpty() {
return
}
fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir)
utils.PanicOnError(err)
tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect)
generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect)
generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect)
generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect)
generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect)
generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect)
generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect)
fmt.Println("Done")
}
func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" {
return tablePostgreSQLBuilderTemplate
}
return tableSQLBuilderTemplate
}
func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s sql builder files...\n", fileTypes)
generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect)
}
func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s model files...\n", fileTypes)
generateGoFiles(destDir, "model", modelTemplate, metaData, dialect)
}
func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) {
modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath)
utils.PanicOnError(err)
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
utils.PanicOnError(err)
for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName})
utils.PanicOnError(err)
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
utils.PanicOnError(err)
}
return
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
"ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier,
"dialect": func() jet.Dialect {
return dialect
},
"param": func(name string) interface{} {
if len(params) > 0 {
return params[0][name]
}
return ""
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -1,213 +0,0 @@
package template
var autoGenWarningTemplate = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
`
var tableSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
{{- end}}
{{- end}}
package {{param "package"}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "")
type {{.GoStructName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
// AS creates new {{.GoStructName}} with assigned alias
func (a {{.GoStructName}}) AS(alias string) {{.GoStructName}} {
return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{.GoStructName}} with assigned schema name
func (a {{.GoStructName}}) FromSchema(schemaName string) {{.GoStructName}} {
return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias())
}
func new{{.GoStructName}}(schemaName, tableName, alias string) {{.GoStructName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{.GoStructName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tablePostgreSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
{{- end}}
{{- end}}
package {{param "package"}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "")
type {{.GoStructImplName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
type {{.GoStructName}} struct {
{{.GoStructImplName}}
EXCLUDED {{.GoStructImplName}}
}
// AS creates new {{.GoStructName}} with assigned alias
func (a {{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{.GoStructName}} with assigned schema name
func (a {{.GoStructName}}) FromSchema(schemaName string) *{{.GoStructName}} {
return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias())
}
func new{{.GoStructName}}(schemaName, tableName, alias string) *{{.GoStructName}} {
return &{{.GoStructName}}{
{{.GoStructImplName}}: new{{.GoStructName}}Impl(schemaName, tableName, alias),
EXCLUDED: new{{.GoStructName}}Impl("", "excluded", ""),
}
}
func new{{.GoStructName}}Impl(schemaName, tableName, alias string) {{.GoStructImplName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{.GoStructImplName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tableModelTemplate = `package model
{{ if .GetImports }}
import (
{{- range .GetImports}}
"{{.}}"
{{- end}}
)
{{end}}
type {{ToGoIdentifier .Name}} struct {
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + `
{{- end}}
}
`
var enumSQLBuilderTemplate = `package enum
import "github.com/go-jet/jet/v2/{{dialect.PackageName}}"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoEnumValueIdentifier $.Name $element}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoEnumValueIdentifier $.Name $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"),
{{- end}}
}
`
var enumModelTemplate = `package model
import "errors"
type {{ToGoIdentifier $.Name}} string
const (
{{- range $index, $element := .Values}}
{{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} {{ToGoIdentifier $.Name}} = "{{$element}}"
{{- end}}
)
func (e *{{ToGoIdentifier $.Name}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("jet: Invalid data for {{ToGoIdentifier $.Name}} enum")
} else {
switch string(v) {
{{- range $index, $element := .Values}}
case "{{$element}}":
*e = {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}}
{{- end}}
default:
return errors.New("jet: Inavlid data " + string(v) + "for {{ToGoIdentifier $.Name}} enum")
}
return nil
}
}
func (e {{ToGoIdentifier $.Name}}) String() string {
return string(e)
}
`

View file

@ -0,0 +1,27 @@
package metadata
// Column struct
type Column struct {
Name string
IsPrimaryKey bool
IsNullable bool
DataType DataType
}
// DataTypeKind is database type kind(base, enum, user-defined, array)
type DataTypeKind string
// DataTypeKind possible values
const (
BaseType DataTypeKind = "base"
EnumType DataTypeKind = "enum"
UserDefinedType DataTypeKind = "user-defined"
ArrayType DataTypeKind = "array"
)
// DataType contains information about column data type
type DataType struct {
Name string
Kind DataTypeKind
IsUnsigned bool
}

View file

@ -0,0 +1,36 @@
package metadata
import (
"database/sql"
"fmt"
)
// TableType is type of database table(view or base)
type TableType string
// SQL table types
const (
BaseTable TableType = "BASE TABLE"
ViewTable TableType = "VIEW"
)
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information
type DialectQuerySet interface {
GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) []Table
GetEnumsMetaData(db *sql.DB, schemaName string) []Enum
}
// GetSchema retrieves Schema information from database
func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema {
ret := Schema{
Name: schemaName,
TablesMetaData: querySet.GetTablesMetaData(db, schemaName, BaseTable),
ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, ViewTable),
EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName),
}
fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),",
len(ret.EnumsMetaData), "enum(s)")
return ret
}

View file

@ -0,0 +1,7 @@
package metadata
// Enum metadata struct
type Enum struct {
Name string `sql:"primary_key"`
Values []string
}

View file

@ -0,0 +1,14 @@
package metadata
// Schema struct
type Schema struct {
Name string
TablesMetaData []Table
ViewsMetaData []Table
EnumsMetaData []Enum
}
// IsEmpty returns true if schema info does not contain any table, views or enums metadata
func (s Schema) IsEmpty() bool {
return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0
}

View file

@ -0,0 +1,22 @@
package metadata
// Table metadata struct
type Table struct {
Name string
Columns []Column
}
// MutableColumns returns list of mutable columns for table
func (t Table) MutableColumns() []Column {
var ret []Column
for _, column := range t.Columns {
if column.IsPrimaryKey {
continue
}
ret = append(ret, column)
}
return ret
}

View file

@ -3,11 +3,14 @@ package mysql
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/internal/metadata" "strings"
"github.com/go-jet/jet/v2/generator/internal/template"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/mysql"
"path" mysqldr "github.com/go-sql-driver/mysql"
) )
// DBConnection contains MySQL connection details // DBConnection contains MySQL connection details
@ -22,34 +25,68 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) (err error) { func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) {
defer utils.ErrorCatch(&err) defer utils.ErrorCatch(&err)
db := openConnection(dbConn) connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" {
connectionString += "?" + dbConn.Params
}
db := openConnection(connectionString)
defer utils.DBClose(db) defer utils.DBClose(db)
fmt.Println("Retrieving database information...") generate(db, dbConn.DBName, destDir, generatorTemplate...)
// No schemas in MySQL
dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{})
genPath := path.Join(destDir, dbConn.DBName)
template.GenerateFiles(genPath, dbInfo, mysql.Dialect)
return nil return nil
} }
func openConnection(dbConn DBConnection) *sql.DB { // GenerateDSN opens connection via DSN string and does everything what Generate does.
var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) {
if dbConn.Params != "" { defer utils.ErrorCatch(&err)
connectionString += "?" + dbConn.Params
// Special case for go mysql driver. It does not understand schema,
// so we need to trim it before passing to generator
// https://github.com/go-sql-driver/mysql#dsn-data-source-name
idx := strings.Index(dsn, "://")
if idx != -1 {
dsn = dsn[idx+len("://"):]
} }
cfg, err := mysqldr.ParseDSN(dsn)
throw.OnError(err)
if cfg.DBName == "" {
panic("database name is required")
}
db := openConnection(dsn)
defer utils.DBClose(db)
generate(db, cfg.DBName, destDir, templates...)
return nil
}
func openConnection(connectionString string) *sql.DB {
fmt.Println("Connecting to MySQL database: " + connectionString) fmt.Println("Connecting to MySQL database: " + connectionString)
db, err := sql.Open("mysql", connectionString) db, err := sql.Open("mysql", connectionString)
utils.PanicOnError(err) throw.OnError(err)
err = db.Ping() err = db.Ping()
utils.PanicOnError(err) throw.OnError(err)
return db return db
} }
func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) {
fmt.Println("Retrieving database information...")
// No schemas in MySQL
schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbName)
genTemplate := template.Default(mysql.Dialect)
if len(templates) > 0 {
genTemplate = templates[0]
}
template.ProcessSchema(destDir, schemaMetaData, genTemplate)
}

View file

@ -1,81 +1,91 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-jet/jet/v2/generator/internal/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"strings" "strings"
) )
// mySqlQuerySet is dialect query set for MySQL // mySqlQuerySet is dialect query set for MySQL
type mySqlQuerySet struct{} type mySqlQuerySet struct{}
func (m *mySqlQuerySet) ListOfTablesQuery() string { func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
return ` query := `
SELECT table_name SELECT table_name as "table.name"
FROM INFORMATION_SCHEMA.tables FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = ?; WHERE table_schema = ? and table_type = ?;
` `
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err)
for i := range tables {
tables[i].Columns = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
} }
func (m *mySqlQuerySet) PrimaryKeysQuery() string { return tables
return ` }
func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
query := `
WITH primaryKeys AS (
SELECT k.column_name SELECT k.column_name
FROM information_schema.table_constraints t FROM information_schema.table_constraints t
JOIN information_schema.key_column_usage k JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name)
USING(constraint_name,table_schema,table_name) WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY'
WHERE t.constraint_type='PRIMARY KEY' )
AND t.table_schema= ? SELECT COLUMN_NAME AS "column.Name",
AND t.table_name= ?; IS_NULLABLE = "YES" AS "column.IsNullable",
` (EXISTS(SELECT 1 FROM primaryKeys AS pk WHERE pk.column_name = columns.column_name)) AS "column.IsPrimaryKey",
} IF (COLUMN_TYPE = 'tinyint(1)',
'boolean',
func (m *mySqlQuerySet) ListOfColumnsQuery() string { IF (DATA_TYPE='enum',
return ` CONCAT(TABLE_NAME, '_', COLUMN_NAME),
SELECT COLUMN_NAME, DATA_TYPE)
IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE), ) AS "dataType.Name",
IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''), IF (DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind",
COLUMN_TYPE LIKE '%unsigned%' COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned"
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = ? and table_name = ? WHERE table_schema = ? AND table_name = ?
ORDER BY ordinal_position; ORDER BY ordinal_position;
` `
var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
throw.OnError(err)
return columns
} }
func (m *mySqlQuerySet) ListOfEnumsQuery() string { func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
return ` query := `
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name",
SUBSTRING(c.COLUMN_TYPE,5) as "values"
FROM information_schema.columns as c FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
` `
var queryResult []struct {
Name string
Values string
} }
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
throw.OnError(err)
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) var ret []metadata.Enum
utils.PanicOnError(err)
defer rows.Close()
ret := []metadata.MetaData{} for _, result := range queryResult {
enumValues := strings.Replace(result.Values[1:len(result.Values)-1], "'", "", -1)
for rows.Next() { ret = append(ret, metadata.Enum{
var enumName string Name: result.Name,
var enumValues string
err = rows.Scan(&enumName, &enumValues)
utils.PanicOnError(err)
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
ret = append(ret, metadata.EnumMetaData{
EnumName: enumName,
Values: strings.Split(enumValues, ","), Values: strings.Split(enumValues, ","),
}) })
} }
err = rows.Err()
utils.PanicOnError(err)
return ret return ret
} }

View file

@ -3,12 +3,16 @@ package postgres
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/internal/metadata" "net/url"
"github.com/go-jet/jet/v2/generator/internal/template"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/postgres"
"path" "path"
"strconv" "strconv"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/postgres"
"github.com/jackc/pgconn"
) )
// DBConnection contains postgres connection details // DBConnection contains postgres connection details
@ -25,38 +29,53 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) (err error) { func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Template) (err error) {
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%s/%s?sslmode=%s",
url.PathEscape(dbConn.User),
url.PathEscape(dbConn.Password),
dbConn.Host,
strconv.Itoa(dbConn.Port),
url.PathEscape(dbConn.DBName),
dbConn.SslMode,
)
return GenerateDSN(dsn, dbConn.SchemaName, destDir, genTemplate...)
}
// GenerateDSN generates jet files using dsn connection string
func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (err error) {
defer utils.ErrorCatch(&err) defer utils.ErrorCatch(&err)
db, err := openConnection(dbConn) cfg, err := pgconn.ParseConfig(dsn)
utils.PanicOnError(err) throw.OnError(err)
if cfg.Database == "" {
panic("database name is required")
}
db := openConnection(dsn)
defer utils.DBClose(db) defer utils.DBClose(db)
fmt.Println("Retrieving schema information...") fmt.Println("Retrieving schema information...")
schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) generatorTemplate := template.Default(postgres.Dialect)
if len(templates) > 0 {
generatorTemplate = templates[0]
}
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, schema)
template.GenerateFiles(genPath, schemaInfo, postgres.Dialect)
dirPath := path.Join(destDir, cfg.Database)
template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate)
return return
} }
func openConnection(dbConn DBConnection) (*sql.DB, error) { func openConnection(dsn string) *sql.DB {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", fmt.Println("Connecting to postgres database: " + dsn)
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString) db, err := sql.Open("postgres", dsn)
throw.OnError(err)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return nil, err
}
err = db.Ping() err = db.Ping()
throw.OnError(err)
if err != nil { return db
return nil, err
}
return db, nil
} }

View file

@ -1,81 +1,83 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-jet/jet/v2/generator/internal/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
) )
// postgresQuerySet is dialect query set for PostgreSQL // postgresQuerySet is dialect query set for PostgreSQL
type postgresQuerySet struct{} type postgresQuerySet struct{}
func (p *postgresQuerySet) ListOfTablesQuery() string { func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
return ` query := `
SELECT table_name SELECT table_name as "table.name"
FROM information_schema.tables FROM information_schema.tables
where table_schema = $1 and table_type = $2; WHERE table_schema = $1 and table_type = $2;
` `
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err)
for i := range tables {
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
} }
func (p *postgresQuerySet) PrimaryKeysQuery() string { return tables
return ` }
SELECT c.column_name
func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
query := `
WITH primaryKeys AS (
SELECT column_name
FROM information_schema.key_column_usage AS c FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name ON t.constraint_name = c.constraint_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'; WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'
` )
} SELECT column_name as "column.Name",
is_nullable = 'YES' as "column.isNullable",
func (p *postgresQuerySet) ListOfColumnsQuery() string { (EXISTS(SELECT 1 from primaryKeys as pk where pk.column_name = columns.column_name)) as "column.IsPrimaryKey",
return ` dataType.kind as "dataType.Kind",
SELECT column_name, is_nullable, data_type, udt_name, FALSE (case dataType.Kind when 'base' then data_type else LTRIM(udt_name, '_') end) as "dataType.Name",
FROM information_schema.columns FALSE as "dataType.isUnsigned"
FROM information_schema.columns,
LATERAL (select (case data_type
when 'ARRAY' then 'array'
when 'USER-DEFINED' then
case (select typtype from pg_type where typname = columns.udt_name)
when 'e' then 'enum'
else 'user-defined'
end
else 'base'
end) as Kind) as dataType
where table_schema = $1 and table_name = $2 where table_schema = $1 and table_name = $2
order by ordinal_position;` order by ordinal_position;
`
var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
throw.OnError(err)
return columns
} }
func (p *postgresQuerySet) ListOfEnumsQuery() string { func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
return ` query := `
SELECT t.typname, SELECT t.typname as "enum.name",
e.enumlabel e.enumlabel as "values"
FROM pg_catalog.pg_type t FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1 WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;` ORDER BY n.nspname, t.typname, e.enumsortorder;`
}
var result []metadata.Enum
func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData {
rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
utils.PanicOnError(err) throw.OnError(err)
defer rows.Close()
return result
enumsInfosMap := map[string][]string{}
for rows.Next() {
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
utils.PanicOnError(err)
enumValues := enumsInfosMap[enumName]
enumValues = append(enumValues, enumValue)
enumsInfosMap[enumName] = enumValues
}
err = rows.Err()
utils.PanicOnError(err)
ret := []metadata.MetaData{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, metadata.EnumMetaData{
EnumName: enumName,
Values: enumValues,
})
}
return ret
} }

View file

@ -0,0 +1,80 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"strings"
)
// sqliteQuerySet is dialect query set for SQLite
type sqliteQuerySet struct{}
func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
query := `
SELECT name as "table.name"
FROM sqlite_master
WHERE type=? AND name != 'sqlite_sequence'
ORDER BY name;
`
sqlTableType := "table"
if tableType == metadata.ViewTable {
sqlTableType = "view"
}
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
throw.OnError(err)
for i := range tables {
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
}
return tables
}
func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
query := fmt.Sprintf(`select * from pragma_table_info(?);`)
var columnInfos []struct {
Name string
Type string
NotNull int32
Pk int32
}
err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
throw.OnError(err)
var columns []metadata.Column
for _, columnInfo := range columnInfos {
columnType := getColumnType(columnInfo.Type)
columns = append(columns, metadata.Column{
Name: columnInfo.Name,
IsPrimaryKey: columnInfo.Pk != 0,
IsNullable: columnInfo.NotNull != 1,
DataType: metadata.DataType{
Name: columnType,
Kind: metadata.BaseType,
IsUnsigned: false,
},
})
}
return columns
}
// will convert VARCHAR(10) -> VARCHAR, etc...
func getColumnType(columnType string) string {
return strings.TrimSpace(strings.Split(columnType, "(")[0])
}
func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
return nil
}

View file

@ -0,0 +1,32 @@
package sqlite
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/sqlite"
)
// GenerateDSN generates jet files using dsn connection string
func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) {
defer utils.ErrorCatch(&err)
db, err := sql.Open("sqlite3", dsn)
throw.OnError(err)
defer utils.DBClose(db)
fmt.Println("Retrieving schema information...")
generatorTemplate := template.Default(sqlite.Dialect)
if len(templates) > 0 {
generatorTemplate = templates[0]
}
schemaMetadata := metadata.GetSchema(db, &sqliteQuerySet{}, "")
template.ProcessSchema(destDir, schemaMetadata, generatorTemplate)
return
}

View file

@ -0,0 +1,229 @@
package template
var autoGenWarningTemplate = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
`
var tableSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- $field := columnField $c}}
{{- if gt $i 0 }}, {{end}}{{$field.Name}}Column
{{- end}}
{{- end}}
package {{package}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "")
type {{tableTemplate.TypeName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
// AS creates new {{tableTemplate.TypeName}} with assigned alias
func (a {{tableTemplate.TypeName}}) AS(alias string) {{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{tableTemplate.TypeName}} with assigned schema name
func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) {{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias())
}
func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableTemplate.TypeName}} {
var (
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{tableTemplate.TypeName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}: {{$field.Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tableSQLBuilderTemplateWithEXCLUDED = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- $field := columnField $c}}
{{- if gt $i 0 }}, {{end}}{{$field.Name}}Column
{{- end}}
{{- end}}
package {{package}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "")
type {{structImplName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
type {{tableTemplate.TypeName}} struct {
{{structImplName}}
EXCLUDED {{structImplName}}
}
// AS creates new {{tableTemplate.TypeName}} with assigned alias
func (a {{tableTemplate.TypeName}}) AS(alias string) *{{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{tableTemplate.TypeName}} with assigned schema name
func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) *{{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias())
}
func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) *{{tableTemplate.TypeName}} {
return &{{tableTemplate.TypeName}}{
{{structImplName}}: new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias),
EXCLUDED: new{{tableTemplate.TypeName}}Impl("", "excluded", ""),
}
}
func new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias string) {{structImplName}} {
var (
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{structImplName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}: {{$field.Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tableModelFileTemplate = `package {{package}}
{{ with modelImports }}
import (
{{- range .}}
"{{.}}"
{{- end}}
)
{{end}}
{{$modelTableTemplate := tableTemplate}}
type {{$modelTableTemplate.TypeName}} struct {
{{- range .Columns}}
{{- $field := structField .}}
{{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + `
{{- end}}
}
`
var enumSQLBuilderTemplate = `package {{package}}
import "github.com/go-jet/jet/v2/{{dialect.PackageName}}"
var {{enumTemplate.InstanceName}} = &struct {
{{- range $index, $value := .Values}}
{{enumValueName $value}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $value := .Values}}
{{enumValueName $value}}: {{dialect.PackageName}}.NewEnumValue("{{$value}}"),
{{- end}}
}
`
var enumModelTemplate = `package {{package}}
{{- $enumTemplate := enumTemplate}}
import "errors"
type {{$enumTemplate.TypeName}} string
const (
{{- range $_, $value := .Values}}
{{valueName $value}} {{$enumTemplate.TypeName}} = "{{$value}}"
{{- end}}
)
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
var enumValue string
switch val := value.(type) {
case string:
enumValue = val
case []byte:
enumValue = string(val)
default:
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
}
switch enumValue {
{{- range $_, $value := .Values}}
case "{{$value}}":
*e = {{valueName $value}}
{{- end}}
default:
return errors.New("jet: Invalid scan value '" + enumValue + "' for {{$enumTemplate.TypeName}} enum")
}
return nil
}
func (e {{$enumTemplate.TypeName}}) String() string {
return string(e)
}
`

View file

@ -0,0 +1,60 @@
package template
import (
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/jet"
)
// Template is generator template used for file generation
type Template struct {
Dialect jet.Dialect
Schema func(schemaMetaData metadata.Schema) Schema
}
// Default is default generator template implementation
func Default(dialect jet.Dialect) Template {
return Template{
Dialect: dialect,
Schema: DefaultSchema,
}
}
// UseSchema replaces current schema generate function with a new implementation and returns new generator template
func (t Template) UseSchema(schemaFunc func(schemaMetaData metadata.Schema) Schema) Template {
t.Schema = schemaFunc
return t
}
// Schema is schema generator template used to generate schema(model and sql builder) files
type Schema struct {
Path string
Model Model
SQLBuilder SQLBuilder
}
// UsePath replaces path and returns new schema template
func (s Schema) UsePath(path string) Schema {
s.Path = path
return s
}
// UseModel returns new schema template with replaced template for model files generation
func (s Schema) UseModel(model Model) Schema {
s.Model = model
return s
}
// UseSQLBuilder returns new schema with replaced template for sql builder files generation
func (s Schema) UseSQLBuilder(sqlBuilder SQLBuilder) Schema {
s.SQLBuilder = sqlBuilder
return s
}
// DefaultSchema returns default schema template implementation
func DefaultSchema(schemaMetaData metadata.Schema) Schema {
return Schema{
Path: schemaMetaData.Name,
Model: DefaultModel(),
SQLBuilder: DefaultSQLBuilder(),
}
}

View file

@ -0,0 +1,327 @@
package template
import (
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/google/uuid"
"path"
"reflect"
"strings"
"time"
)
// Model is template for model files generation
type Model struct {
Skip bool
Path string
Table func(table metadata.Table) TableModel
View func(table metadata.Table) ViewModel
Enum func(enum metadata.Enum) EnumModel
}
// PackageName returns package name of model types
func (m Model) PackageName() string {
return path.Base(m.Path)
}
// UsePath returns new Model template with replaced file path
func (m Model) UsePath(path string) Model {
m.Path = path
return m
}
// UseTable returns new Model template with replaced template for table model files generation
func (m Model) UseTable(tableModelFunc func(table metadata.Table) TableModel) Model {
m.Table = tableModelFunc
return m
}
// UseView returns new Model template with replaced template for view model files generation
func (m Model) UseView(tableModelFunc func(table metadata.Table) TableModel) Model {
m.View = tableModelFunc
return m
}
// UseEnum returns new Model template with replaced template for enum model files generation
func (m Model) UseEnum(enumFunc func(enumMetaData metadata.Enum) EnumModel) Model {
m.Enum = enumFunc
return m
}
// DefaultModel returns default Model template implementation
func DefaultModel() Model {
return Model{
Skip: false,
Path: "/model",
Table: DefaultTableModel,
View: DefaultViewModel,
Enum: DefaultEnumModel,
}
}
// TableModel is template for table model files generation
type TableModel struct {
Skip bool
FileName string
TypeName string
Field func(columnMetaData metadata.Column) TableModelField
}
// ViewModel is template for view model files generation
type ViewModel = TableModel
// DefaultViewModel is default view template implementation
var DefaultViewModel = DefaultTableModel
// DefaultTableModel is default table template implementation
func DefaultTableModel(tableMetaData metadata.Table) TableModel {
return TableModel{
FileName: utils.ToGoFileName(tableMetaData.Name),
TypeName: utils.ToGoIdentifier(tableMetaData.Name),
Field: DefaultTableModelField,
}
}
// UseFileName returns new TableModel with new file name set
func (t TableModel) UseFileName(fileName string) TableModel {
t.FileName = fileName
return t
}
// UseTypeName returns new TableModel with new type name set
func (t TableModel) UseTypeName(typeName string) TableModel {
t.TypeName = typeName
return t
}
// UseField returns new TableModel with new TableModelField template function
func (t TableModel) UseField(structFieldFunc func(columnMetaData metadata.Column) TableModelField) TableModel {
t.Field = structFieldFunc
return t
}
func getTableModelImports(modelType TableModel, tableMetaData metadata.Table) []string {
importPaths := map[string]bool{}
for _, columnMetaData := range tableMetaData.Columns {
field := modelType.Field(columnMetaData)
importPath := field.Type.ImportPath
if importPath != "" {
importPaths[importPath] = true
}
}
var ret []string
for importPath := range importPaths {
ret = append(ret, importPath)
}
return ret
}
// EnumModel is template for enum model files generation
type EnumModel struct {
Skip bool
FileName string
TypeName string
ValueName func(value string) string
}
// UseFileName returns new EnumModel with new file name set
func (em EnumModel) UseFileName(fileName string) EnumModel {
em.FileName = fileName
return em
}
// UseTypeName returns new EnumModel with new type name set
func (em EnumModel) UseTypeName(typeName string) EnumModel {
em.TypeName = typeName
return em
}
// DefaultEnumModel returns default implementation for EnumModel
func DefaultEnumModel(enumMetaData metadata.Enum) EnumModel {
typeName := utils.ToGoIdentifier(enumMetaData.Name)
return EnumModel{
FileName: utils.ToGoFileName(enumMetaData.Name),
TypeName: typeName,
ValueName: func(value string) string {
return typeName + "_" + utils.ToGoIdentifier(value)
},
}
}
// TableModelField is template for table model field generation
type TableModelField struct {
Name string
Type Type
Tags []string
}
// DefaultTableModelField returns default TableModelField implementation
func DefaultTableModelField(columnMetaData metadata.Column) TableModelField {
var tags []string
if columnMetaData.IsPrimaryKey {
tags = append(tags, `sql:"primary_key"`)
}
return TableModelField{
Name: utils.ToGoIdentifier(columnMetaData.Name),
Type: getType(columnMetaData),
Tags: tags,
}
}
// UseType returns new TypeModelField with a new field type set
func (f TableModelField) UseType(t Type) TableModelField {
f.Type = t
return f
}
// UseName returns new TableModelField implementation with new field name set
func (f TableModelField) UseName(name string) TableModelField {
f.Name = name
return f
}
// UseTags returns new TableModelField implementation with additional tags added.
func (f TableModelField) UseTags(tags ...string) TableModelField {
f.Tags = append(f.Tags, tags...)
return f
}
// TagsString returns tags string representation
func (f TableModelField) TagsString() string {
if len(f.Tags) == 0 {
return ""
}
return fmt.Sprintf("`%s`", strings.Join(f.Tags, " "))
}
// Type represents type of the struct field
type Type struct {
ImportPath string
Name string
}
// NewType creates new type for dummy object
func NewType(dummyObject interface{}) Type {
return Type{
ImportPath: getImportPath(dummyObject),
Name: getTypeName(dummyObject),
}
}
func getTypeName(t interface{}) string {
typeStr := reflect.TypeOf(t).String()
typeStr = strings.Replace(typeStr, "[]uint8", "[]byte", -1)
return typeStr
}
func getImportPath(dummyData interface{}) string {
dataType := reflect.TypeOf(dummyData)
if dataType.Kind() == reflect.Ptr {
return dataType.Elem().PkgPath()
}
return dataType.PkgPath()
}
func getType(columnMetadata metadata.Column) Type {
userDefinedType := getUserDefinedType(columnMetadata)
if userDefinedType != "" {
if columnMetadata.IsNullable {
return Type{Name: "*" + userDefinedType}
}
return Type{Name: userDefinedType}
}
return NewType(getGoType(columnMetadata))
}
func getUserDefinedType(column metadata.Column) string {
switch column.DataType.Kind {
case metadata.EnumType:
return utils.ToGoIdentifier(column.DataType.Name)
case metadata.UserDefinedType, metadata.ArrayType:
return "string"
}
return ""
}
func getGoType(column metadata.Column) interface{} {
defaultGoType := toGoType(column)
if column.IsNullable {
return reflect.New(reflect.TypeOf(defaultGoType)).Interface()
}
return defaultGoType
}
// toGoType returns model type for column info.
func toGoType(column metadata.Column) interface{} {
switch strings.ToLower(column.DataType.Name) {
case "user-defined", "enum":
return ""
case "boolean", "bool":
return false
case "tinyint":
if column.DataType.IsUnsigned {
return uint8(0)
}
return int8(0)
case "smallint", "int2",
"year":
if column.DataType.IsUnsigned {
return uint16(0)
}
return int16(0)
case "integer", "int4",
"mediumint", "int": //MySQL
if column.DataType.IsUnsigned {
return uint32(0)
}
return int32(0)
case "bigint", "int8":
if column.DataType.IsUnsigned {
return uint64(0)
}
return int64(0)
case "date",
"timestamp without time zone", "timestamp",
"timestamp with time zone", "timestamptz",
"time without time zone", "time",
"time with time zone", "timetz",
"datetime": // MySQL
return time.Time{}
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return []byte("")
case "text",
"character", "bpchar",
"character varying", "varchar", "nvarchar",
"tsvector", "bit", "bit varying", "varbit",
"money", "json", "jsonb",
"xml", "point", "interval", "line", "array",
"char", "tinytext", "mediumtext", "longtext": // MySQL
return ""
case "real", "float4":
return float32(0.0)
case "numeric", "decimal",
"double precision", "float8", "float",
"double": // MySQL
return float64(0.0)
case "uuid":
return uuid.UUID{}
default:
fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.")
return ""
}
}

View file

@ -0,0 +1,45 @@
package template
import (
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/stretchr/testify/require"
"testing"
)
func Test_TableModelField(t *testing.T) {
require.Equal(t, DefaultTableModelField(metadata.Column{
Name: "col_name",
IsPrimaryKey: true,
IsNullable: true,
DataType: metadata.DataType{
Name: "smallint",
Kind: "base",
IsUnsigned: true,
},
}), TableModelField{
Name: "ColName",
Type: Type{
ImportPath: "",
Name: "*uint16",
},
Tags: []string{"sql:\"primary_key\""},
})
require.Equal(t, DefaultTableModelField(metadata.Column{
Name: "time_column_1",
IsPrimaryKey: false,
IsNullable: true,
DataType: metadata.DataType{
Name: "timestamp with time zone",
Kind: "base",
IsUnsigned: false,
},
}), TableModelField{
Name: "TimeColumn1",
Type: Type{
ImportPath: "time",
Name: "*time.Time",
},
Tags: nil,
})
}

View file

@ -0,0 +1,269 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"path"
"strings"
"text/template"
)
// ProcessSchema will process schema metadata and constructs go files using generator Template
func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) {
if schemaMetaData.IsEmpty() {
return
}
schemaTemplate := generatorTemplate.Schema(schemaMetaData)
schemaPath := path.Join(dirPath, schemaTemplate.Path)
fmt.Println("Destination directory:", schemaPath)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(schemaPath)
throw.OnError(err)
processModel(schemaPath, schemaMetaData, schemaTemplate)
processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate)
}
func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) {
modelTemplate := schemaTemplate.Model
if modelTemplate.Skip {
fmt.Println("Skipping the generation of model types.")
return
}
modelDirPath := path.Join(dirPath, modelTemplate.Path)
err := utils.EnsureDirPath(modelDirPath)
throw.OnError(err)
processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate)
processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate)
processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate)
}
func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) {
sqlBuilderTemplate := schemaTemplate.SQLBuilder
if sqlBuilderTemplate.Skip {
fmt.Println("Skipping the generation of SQL Builder types.")
return
}
sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path)
processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate)
processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate)
processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate)
}
func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) {
if len(enumsMetaData) == 0 {
return
}
fmt.Printf("Generating enum sql builder files\n")
for _, enumMetaData := range enumsMetaData {
enumTemplate := sqlBuilder.Enum(enumMetaData)
if enumTemplate.Skip {
continue
}
enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path)
err := utils.EnsureDirPath(enumSQLBuilderPath)
throw.OnError(err)
text, err := generateTemplate(
autoGenWarningTemplate+enumSQLBuilderTemplate,
enumMetaData,
template.FuncMap{
"package": func() string {
return enumTemplate.PackageName()
},
"dialect": func() jet.Dialect {
return dialect
},
"enumTemplate": func() EnumSQLBuilder {
return enumTemplate
},
"enumValueName": func(enumValue string) string {
return enumTemplate.ValueName(enumValue)
},
})
throw.OnError(err)
err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text)
throw.OnError(err)
}
}
func processTableSQLBuilder(fileTypes, dirPath string,
dialect jet.Dialect,
schemaMetaData metadata.Schema,
tablesMetaData []metadata.Table,
sqlBuilderTemplate SQLBuilder) {
if len(tablesMetaData) == 0 {
return
}
fmt.Printf("Generating %s sql builder files\n", fileTypes)
for _, tableMetaData := range tablesMetaData {
var tableSQLBuilderTemplate TableSQLBuilder
if fileTypes == "view" {
tableSQLBuilderTemplate = sqlBuilderTemplate.View(tableMetaData)
} else {
tableSQLBuilderTemplate = sqlBuilderTemplate.Table(tableMetaData)
}
if tableSQLBuilderTemplate.Skip {
continue
}
tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilderTemplate.Path)
err := utils.EnsureDirPath(tableSQLBuilderPath)
throw.OnError(err)
text, err := generateTemplate(
autoGenWarningTemplate+getTableSQLBuilderTemplate(dialect),
tableMetaData,
template.FuncMap{
"package": func() string {
return tableSQLBuilderTemplate.PackageName()
},
"dialect": func() jet.Dialect {
return dialect
},
"schemaName": func() string {
return schemaMetaData.Name
},
"tableTemplate": func() TableSQLBuilder {
return tableSQLBuilderTemplate
},
"structImplName": func() string { // postgres only
structName := tableSQLBuilderTemplate.TypeName
return string(strings.ToLower(structName)[0]) + structName[1:]
},
"columnField": func(columnMetaData metadata.Column) TableSQLBuilderColumn {
return tableSQLBuilderTemplate.Column(columnMetaData)
},
})
throw.OnError(err)
err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilderTemplate.FileName, text)
throw.OnError(err)
}
}
func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" || dialect.Name() == "SQLite" {
return tableSQLBuilderTemplateWithEXCLUDED
}
return tableSQLBuilderTemplate
}
func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) {
if len(tablesMetaData) == 0 {
return
}
fmt.Printf("Generating %s model files...\n", fileTypes)
for _, tableMetaData := range tablesMetaData {
var tableTemplate TableModel
if fileTypes == "table" {
tableTemplate = modelTemplate.Table(tableMetaData)
} else {
tableTemplate = modelTemplate.View(tableMetaData)
}
if tableTemplate.Skip {
continue
}
text, err := generateTemplate(
autoGenWarningTemplate+tableModelFileTemplate,
tableMetaData,
template.FuncMap{
"package": func() string {
return modelTemplate.PackageName()
},
"modelImports": func() []string {
return getTableModelImports(tableTemplate, tableMetaData)
},
"tableTemplate": func() TableModel {
return tableTemplate
},
"structField": func(columnMetaData metadata.Column) TableModelField {
return tableTemplate.Field(columnMetaData)
},
})
throw.OnError(err)
err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text)
throw.OnError(err)
}
}
func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) {
if len(enumsMetaData) == 0 {
return
}
fmt.Print("Generating enum model files...\n")
for _, enumMetaData := range enumsMetaData {
enumTemplate := modelTemplate.Enum(enumMetaData)
if enumTemplate.Skip {
continue
}
text, err := generateTemplate(
autoGenWarningTemplate+enumModelTemplate,
enumMetaData,
template.FuncMap{
"package": func() string {
return modelTemplate.PackageName()
},
"enumTemplate": func() EnumModel {
return enumTemplate
},
"valueName": func(value string) string {
return enumTemplate.ValueName(value)
},
})
throw.OnError(err)
err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text)
throw.OnError(err)
}
}
func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -0,0 +1,226 @@
package template
import (
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils"
"path"
"strings"
"unicode"
)
// SQLBuilder is template for generating sql builder files
type SQLBuilder struct {
Skip bool
Path string
Table func(table metadata.Table) TableSQLBuilder
View func(view metadata.Table) TableSQLBuilder
Enum func(enum metadata.Enum) EnumSQLBuilder
}
// DefaultSQLBuilder returns default SQLBuilder implementation
func DefaultSQLBuilder() SQLBuilder {
return SQLBuilder{
Path: "",
Table: DefaultTableSQLBuilder,
View: DefaultViewSQLBuilder,
Enum: DefaultEnumSQLBuilder,
}
}
// UsePath returns new SQLBuilder with new relative path set
func (sb SQLBuilder) UsePath(path string) SQLBuilder {
sb.Path = path
return sb
}
// UseTable returns new SQLBuilder with new TableSQLBuilder template function set
func (sb SQLBuilder) UseTable(tableFunc func(table metadata.Table) TableSQLBuilder) SQLBuilder {
sb.Table = tableFunc
return sb
}
// UseView returns new SQLBuilder with new ViewSQLBuilder template function set
func (sb SQLBuilder) UseView(viewFunc func(table metadata.Table) ViewSQLBuilder) SQLBuilder {
sb.View = viewFunc
return sb
}
// UseEnum returns new SQLBuilder with new EnumSQLBuilder template function set
func (sb SQLBuilder) UseEnum(enumFunc func(enum metadata.Enum) EnumSQLBuilder) SQLBuilder {
sb.Enum = enumFunc
return sb
}
// TableSQLBuilder is template for generating table SQLBuilder files
type TableSQLBuilder struct {
Skip bool
Path string
FileName string
InstanceName string
TypeName string
Column func(columnMetaData metadata.Column) TableSQLBuilderColumn
}
// ViewSQLBuilder is template for generating view SQLBuilder files
type ViewSQLBuilder = TableSQLBuilder
// DefaultTableSQLBuilder returns default implementation for TableSQLBuilder
func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder {
return TableSQLBuilder{
Path: "/table",
FileName: utils.ToGoFileName(tableMetaData.Name),
InstanceName: utils.ToGoIdentifier(tableMetaData.Name),
TypeName: utils.ToGoIdentifier(tableMetaData.Name) + "Table",
Column: DefaultTableSQLBuilderColumn,
}
}
// DefaultViewSQLBuilder returns default implementation for ViewSQLBuilder
func DefaultViewSQLBuilder(viewMetaData metadata.Table) ViewSQLBuilder {
tableSQLBuilder := DefaultTableSQLBuilder(viewMetaData)
tableSQLBuilder.Path = "/view"
return tableSQLBuilder
}
// PackageName returns package name of table sql builder types
func (tb TableSQLBuilder) PackageName() string {
return path.Base(tb.Path)
}
// UsePath returns new TableSQLBuilder with new relative path set
func (tb TableSQLBuilder) UsePath(path string) TableSQLBuilder {
tb.Path = path
return tb
}
// UseFileName returns new TableSQLBuilder with new file name set
func (tb TableSQLBuilder) UseFileName(name string) TableSQLBuilder {
tb.FileName = name
return tb
}
// UseInstanceName returns new TableSQLBuilder with new instance name set
func (tb TableSQLBuilder) UseInstanceName(name string) TableSQLBuilder {
tb.InstanceName = name
return tb
}
// UseTypeName returns new TableSQLBuilder with new type name set
func (tb TableSQLBuilder) UseTypeName(name string) TableSQLBuilder {
tb.TypeName = name
return tb
}
// UseColumn returns new TableSQLBuilder with new column template function set
func (tb TableSQLBuilder) UseColumn(columnsFunc func(column metadata.Column) TableSQLBuilderColumn) TableSQLBuilder {
tb.Column = columnsFunc
return tb
}
// TableSQLBuilderColumn is template for table sql builder column
type TableSQLBuilderColumn struct {
Name string
Type string
}
// DefaultTableSQLBuilderColumn returns default implementation of TableSQLBuilderColumn
func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilderColumn {
return TableSQLBuilderColumn{
Name: utils.ToGoIdentifier(columnMetaData.Name),
Type: getSqlBuilderColumnType(columnMetaData),
}
}
// getSqlBuilderColumnType returns type of jet sql builder column
func getSqlBuilderColumnType(columnMetaData metadata.Column) string {
if columnMetaData.DataType.Kind != metadata.BaseType {
return "String"
}
switch strings.ToLower(columnMetaData.DataType.Name) {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "interval":
return "Interval"
case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "nvarchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.")
return "String"
}
}
// EnumSQLBuilder is template for generating enum SQLBuilder files
type EnumSQLBuilder struct {
Skip bool
Path string
FileName string
InstanceName string
ValueName func(enumValue string) string
}
// DefaultEnumSQLBuilder returns default implementation of EnumSQLBuilder
func DefaultEnumSQLBuilder(enumMetaData metadata.Enum) EnumSQLBuilder {
return EnumSQLBuilder{
Path: "/enum",
FileName: utils.ToGoFileName(enumMetaData.Name),
InstanceName: utils.ToGoIdentifier(enumMetaData.Name),
ValueName: func(enumValue string) string {
return defaultEnumValueName(enumMetaData.Name, enumValue)
},
}
}
// PackageName returns enum sql builder package name
func (e EnumSQLBuilder) PackageName() string {
return path.Base(e.Path)
}
// UsePath returns new EnumSQLBuilder with new path set
func (e EnumSQLBuilder) UsePath(path string) EnumSQLBuilder {
e.Path = path
return e
}
// UseFileName returns new EnumSQLBuilder with new file name set
func (e EnumSQLBuilder) UseFileName(name string) EnumSQLBuilder {
e.FileName = name
return e
}
// UseInstanceName returns new EnumSQLBuilder with instance name set
func (e EnumSQLBuilder) UseInstanceName(name string) EnumSQLBuilder {
e.InstanceName = name
return e
}
func defaultEnumValueName(enumName, enumValue string) string {
enumValueName := utils.ToGoIdentifier(enumValue)
if !unicode.IsLetter([]rune(enumValueName)[0]) {
return utils.ToGoIdentifier(enumName) + enumValueName
}
return enumValueName
}

View file

@ -0,0 +1,11 @@
package template
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestToGoEnumValueIdentifier(t *testing.T) {
require.Equal(t, defaultEnumValueName("enum_name", "enum_value"), "EnumValue")
require.Equal(t, defaultEnumValueName("NumEnum", "100"), "NumEnum100")
}

2
go.mod
View file

@ -6,8 +6,10 @@ require (
github.com/go-sql-driver/mysql v1.5.0 github.com/go-sql-driver/mysql v1.5.0
github.com/google/go-cmp v0.5.0 //tests github.com/google/go-cmp v0.5.0 //tests
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/jackc/pgconn v1.8.1
github.com/jackc/pgx/v4 v4.11.0 //tests github.com/jackc/pgx/v4 v4.11.0 //tests
github.com/lib/pq v1.7.0 github.com/lib/pq v1.7.0
github.com/mattn/go-sqlite3 v1.14.8
github.com/pkg/profile v1.5.0 //tests github.com/pkg/profile v1.5.0 //tests
github.com/shopspring/decimal v1.2.0 // tests github.com/shopspring/decimal v1.2.0 // tests
github.com/stretchr/testify v1.6.1 // tests github.com/stretchr/testify v1.6.1 // tests

4
go.sum
View file

@ -42,7 +42,6 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7
github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -219,6 +218,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
@ -457,7 +458,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View file

@ -9,8 +9,12 @@ import (
) )
// SnakeToCamel returns a string converted from snake case to uppercase // SnakeToCamel returns a string converted from snake case to uppercase
func SnakeToCamel(s string) string { func SnakeToCamel(s string, firstLetterUppercase ...bool) string {
return snakeToCamel(s, true) upperCase := true
if len(firstLetterUppercase) > 0 {
upperCase = firstLetterUppercase[0]
}
return snakeToCamel(s, upperCase)
} }
func snakeToCamel(s string, upperCase bool) string { func snakeToCamel(s string, upperCase bool) string {

View file

@ -223,6 +223,7 @@ type ClauseSetStmtOperator struct {
OrderBy ClauseOrderBy OrderBy ClauseOrderBy
Limit ClauseLimit Limit ClauseLimit
Offset ClauseOffset Offset ClauseOffset
SkipSelectWrap bool
} }
// Projections returns set of projections for ClauseSetStmtOperator // Projections returns set of projections for ClauseSetStmtOperator
@ -242,6 +243,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB
for i, selectStmt := range s.Selects { for i, selectStmt := range s.Selects {
out.NewLine() out.NewLine()
if i > 0 { if i > 0 {
if s.SkipSelectWrap {
out.NewLine()
}
out.WriteString(s.Operator) out.WriteString(s.Operator)
if s.All { if s.All {
@ -254,7 +259,11 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB
panic("jet: select statement of '" + s.Operator + "' is nil") panic("jet: select statement of '" + s.Operator + "' is nil")
} }
selectStmt.serialize(statementType, out, FallTrough(options)...) if s.SkipSelectWrap {
options = append(FallTrough(options), NoWrap)
}
selectStmt.serialize(statementType, out, options...)
} }
s.OrderBy.Serialize(statementType, out) s.OrderBy.Serialize(statementType, out)
@ -360,10 +369,6 @@ type ClauseValuesQuery struct {
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(v.Rows) == 0 && v.Query == nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
if len(v.Rows) > 0 && v.Query != nil { if len(v.Rows) > 0 && v.Query != nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement") panic("jet: VALUES or QUERY has to be specified for INSERT statement")
} }
@ -406,6 +411,7 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, o
// ClauseQuery struct // ClauseQuery struct
type ClauseQuery struct { type ClauseQuery struct {
Query SerializerStatement Query SerializerStatement
SkipSelectWrap bool
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
@ -414,7 +420,11 @@ func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, op
return return
} }
v.Query.serialize(statementType, out, FallTrough(options)...) if v.SkipSelectWrap {
options = append(FallTrough(options), NoWrap)
}
v.Query.serialize(statementType, out, options...)
} }
// ClauseDelete struct // ClauseDelete struct
@ -561,3 +571,26 @@ type KeywordClause struct {
func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
k.serialize(statementType, out, FallTrough(options)...) k.serialize(statementType, out, FallTrough(options)...)
} }
// ClauseReturning type
type ClauseReturning struct {
ProjectionList []Projection
}
// Serialize for ClauseReturning
func (r *ClauseReturning) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(r.ProjectionList) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.ProjectionList)
out.DecreaseIdent()
}
// Projections for ClauseReturning
func (r ClauseReturning) Projections() ProjectionList {
return r.ProjectionList
}

View file

@ -11,6 +11,28 @@ func (cl ColumnList) SET(expression Expression) ColumnAssigment {
} }
} }
// Except will create new column list in which columns contained in excluded column names are removed
func (cl ColumnList) Except(excludedColumns ...Column) ColumnList {
excludedColumnList := UnwidColumnList(excludedColumns)
excludedColumnNames := map[string]bool{}
for _, excludedColumn := range excludedColumnList {
excludedColumnNames[excludedColumn.Name()] = true
}
var ret ColumnList
for _, column := range cl {
if excludedColumnNames[column.Name()] {
continue
}
ret = append(ret, column)
}
return ret
}
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{} newProjectionList := ProjectionList{}

View file

@ -2,7 +2,7 @@ package jet
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression { func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil) return NewFunc("ROW", expressions, nil)
} }
// ------------------ Mathematical functions ---------------// // ------------------ Mathematical functions ---------------//
@ -265,118 +265,118 @@ func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
// LOWER returns string expression in lower case // LOWER returns string expression in lower case
func LOWER(stringExpression StringExpression) StringExpression { func LOWER(stringExpression StringExpression) StringExpression {
return newStringFunc("LOWER", stringExpression) return NewStringFunc("LOWER", stringExpression)
} }
// UPPER returns string expression in upper case // UPPER returns string expression in upper case
func UPPER(stringExpression StringExpression) StringExpression { func UPPER(stringExpression StringExpression) StringExpression {
return newStringFunc("UPPER", stringExpression) return NewStringFunc("UPPER", stringExpression)
} }
// BTRIM removes the longest string consisting only of characters // BTRIM removes the longest string consisting only of characters
// in characters (a space by default) from the start and end of string // in characters (a space by default) from the start and end of string
func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return newStringFunc("BTRIM", stringExpression, trimChars[0]) return NewStringFunc("BTRIM", stringExpression, trimChars[0])
} }
return newStringFunc("BTRIM", stringExpression) return NewStringFunc("BTRIM", stringExpression)
} }
// LTRIM removes the longest string containing only characters // LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string // from characters (a space by default) from the start of string
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return newStringFunc("LTRIM", str, trimChars[0]) return NewStringFunc("LTRIM", str, trimChars[0])
} }
return newStringFunc("LTRIM", str) return NewStringFunc("LTRIM", str)
} }
// RTRIM removes the longest string containing only characters // RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string // from characters (a space by default) from the end of string
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return newStringFunc("RTRIM", str, trimChars[0]) return NewStringFunc("RTRIM", str, trimChars[0])
} }
return newStringFunc("RTRIM", str) return NewStringFunc("RTRIM", str)
} }
// CHR returns character with the given code. // CHR returns character with the given code.
func CHR(integerExpression IntegerExpression) StringExpression { func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression) return NewStringFunc("CHR", integerExpression)
} }
// CONCAT adds two or more expressions together // CONCAT adds two or more expressions together
func CONCAT(expressions ...Expression) StringExpression { func CONCAT(expressions ...Expression) StringExpression {
return newStringFunc("CONCAT", expressions...) return NewStringFunc("CONCAT", expressions...)
} }
// CONCAT_WS adds two or more expressions together with a separator. // CONCAT_WS adds two or more expressions together with a separator.
func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression {
return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...) return NewStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...)
} }
// CONVERT converts string to dest_encoding. The original encoding is // CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding. // specified by src_encoding. The string must be valid in this encoding.
func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT", str, srcEncoding, destEncoding) return NewStringFunc("CONVERT", str, srcEncoding, destEncoding)
} }
// CONVERT_FROM converts string to the database encoding. The original // CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding. // encoding is specified by src_encoding. The string must be valid in this encoding.
func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_FROM", str, srcEncoding) return NewStringFunc("CONVERT_FROM", str, srcEncoding)
} }
// CONVERT_TO converts string to dest_encoding. // CONVERT_TO converts string to dest_encoding.
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_TO", str, toEncoding) return NewStringFunc("CONVERT_TO", str, toEncoding)
} }
// ENCODE encodes binary data into a textual representation. // ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and // Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
func ENCODE(data StringExpression, format StringExpression) StringExpression { func ENCODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("ENCODE", data, format) return NewStringFunc("ENCODE", data, format)
} }
// DECODE decodes binary data from textual representation in string. // DECODE decodes binary data from textual representation in string.
// Options for format are same as in encode. // Options for format are same as in encode.
func DECODE(data StringExpression, format StringExpression) StringExpression { func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format) return NewStringFunc("DECODE", data, format)
} }
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
args := []Expression{formatStr} args := []Expression{formatStr}
args = append(args, formatArgs...) args = append(args, formatArgs...)
return newStringFunc("FORMAT", args...) return NewStringFunc("FORMAT", args...)
} }
// INITCAP converts the first letter of each word to upper case // INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric // and the rest to lower case. Words are sequences of alphanumeric
// characters separated by non-alphanumeric characters. // characters separated by non-alphanumeric characters.
func INITCAP(str StringExpression) StringExpression { func INITCAP(str StringExpression) StringExpression {
return newStringFunc("INITCAP", str) return NewStringFunc("INITCAP", str)
} }
// LEFT returns first n characters in the string. // LEFT returns first n characters in the string.
// When n is negative, return all but last |n| characters. // When n is negative, return all but last |n| characters.
func LEFT(str StringExpression, n IntegerExpression) StringExpression { func LEFT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("LEFT", str, n) return NewStringFunc("LEFT", str, n)
} }
// RIGHT returns last n characters in the string. // RIGHT returns last n characters in the string.
// When n is negative, return all but first |n| characters. // When n is negative, return all but first |n| characters.
func RIGHT(str StringExpression, n IntegerExpression) StringExpression { func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("RIGHT", str, n) return NewStringFunc("RIGHT", str, n)
} }
// LENGTH returns number of characters in string with a given encoding // LENGTH returns number of characters in string with a given encoding
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 { if len(encoding) > 0 {
return newStringFunc("LENGTH", str, encoding[0]) return NewStringFunc("LENGTH", str, encoding[0])
} }
return newStringFunc("LENGTH", str) return NewStringFunc("LENGTH", str)
} }
// LPAD fills up the string to length length by prepending the characters // LPAD fills up the string to length length by prepending the characters
@ -384,40 +384,40 @@ func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression
// then it is truncated (on the right). // then it is truncated (on the right).
func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 { if len(text) > 0 {
return newStringFunc("LPAD", str, length, text[0]) return NewStringFunc("LPAD", str, length, text[0])
} }
return newStringFunc("LPAD", str, length) return NewStringFunc("LPAD", str, length)
} }
// RPAD fills up the string to length length by appending the characters // RPAD fills up the string to length length by appending the characters
// fill (a space by default). If the string is already longer than length then it is truncated. // fill (a space by default). If the string is already longer than length then it is truncated.
func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 { if len(text) > 0 {
return newStringFunc("RPAD", str, length, text[0]) return NewStringFunc("RPAD", str, length, text[0])
} }
return newStringFunc("RPAD", str, length) return NewStringFunc("RPAD", str, length)
} }
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal // MD5 calculates the MD5 hash of string, returning the result in hexadecimal
func MD5(stringExpression StringExpression) StringExpression { func MD5(stringExpression StringExpression) StringExpression {
return newStringFunc("MD5", stringExpression) return NewStringFunc("MD5", stringExpression)
} }
// REPEAT repeats string the specified number of times // REPEAT repeats string the specified number of times
func REPEAT(str StringExpression, n IntegerExpression) StringExpression { func REPEAT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("REPEAT", str, n) return NewStringFunc("REPEAT", str, n)
} }
// REPLACE replaces all occurrences in string of substring from with substring to // REPLACE replaces all occurrences in string of substring from with substring to
func REPLACE(text, from, to StringExpression) StringExpression { func REPLACE(text, from, to StringExpression) StringExpression {
return newStringFunc("REPLACE", text, from, to) return NewStringFunc("REPLACE", text, from, to)
} }
// REVERSE returns reversed string. // REVERSE returns reversed string.
func REVERSE(stringExpression StringExpression) StringExpression { func REVERSE(stringExpression StringExpression) StringExpression {
return newStringFunc("REVERSE", stringExpression) return NewStringFunc("REVERSE", stringExpression)
} }
// STRPOS returns location of specified substring (same as position(substring in string), // STRPOS returns location of specified substring (same as position(substring in string),
@ -429,22 +429,22 @@ func STRPOS(str, substring StringExpression) IntegerExpression {
// SUBSTR extracts substring // SUBSTR extracts substring
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 { if len(count) > 0 {
return newStringFunc("SUBSTR", str, from, count[0]) return NewStringFunc("SUBSTR", str, from, count[0])
} }
return newStringFunc("SUBSTR", str, from) return NewStringFunc("SUBSTR", str, from)
} }
// TO_ASCII convert string to ASCII from another encoding // TO_ASCII convert string to ASCII from another encoding
func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression { func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 { if len(encoding) > 0 {
return newStringFunc("TO_ASCII", str, encoding[0]) return NewStringFunc("TO_ASCII", str, encoding[0])
} }
return newStringFunc("TO_ASCII", str) return NewStringFunc("TO_ASCII", str)
} }
// TO_HEX converts number to its equivalent hexadecimal representation // TO_HEX converts number to its equivalent hexadecimal representation
func TO_HEX(number IntegerExpression) StringExpression { func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number) return NewStringFunc("TO_HEX", number)
} }
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
@ -460,12 +460,12 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType
// TO_CHAR converts expression to string with format // TO_CHAR converts expression to string with format
func TO_CHAR(expression Expression, format StringExpression) StringExpression { func TO_CHAR(expression Expression, format StringExpression) StringExpression {
return newStringFunc("TO_CHAR", expression, format) return NewStringFunc("TO_CHAR", expression, format)
} }
// TO_DATE converts string to date using format // TO_DATE converts string to date using format
func TO_DATE(dateStr, format StringExpression) DateExpression { func TO_DATE(dateStr, format StringExpression) DateExpression {
return newDateFunc("TO_DATE", dateStr, format) return NewDateFunc("TO_DATE", dateStr, format)
} }
// TO_NUMBER converts string to numeric using format // TO_NUMBER converts string to numeric using format
@ -482,7 +482,7 @@ func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression {
// CURRENT_DATE returns current date // CURRENT_DATE returns current date
func CURRENT_DATE() DateExpression { func CURRENT_DATE() DateExpression {
dateFunc := newDateFunc("CURRENT_DATE") dateFunc := NewDateFunc("CURRENT_DATE")
dateFunc.noBrackets = true dateFunc.noBrackets = true
return dateFunc return dateFunc
} }
@ -522,9 +522,9 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc var timeFunc *timeFunc
if len(precision) > 0 { if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0])) timeFunc = NewTimeFunc("LOCALTIME", FixedLiteral(precision[0]))
} else { } else {
timeFunc = newTimeFunc("LOCALTIME") timeFunc = NewTimeFunc("LOCALTIME")
} }
timeFunc.noBrackets = true timeFunc.noBrackets = true
@ -558,26 +558,26 @@ func NOW() TimestampzExpression {
func COALESCE(value Expression, values ...Expression) Expression { func COALESCE(value Expression, values ...Expression) Expression {
var allValues = []Expression{value} var allValues = []Expression{value}
allValues = append(allValues, values...) allValues = append(allValues, values...)
return newFunc("COALESCE", allValues, nil) return NewFunc("COALESCE", allValues, nil)
} }
// NULLIF function returns a null value if value1 equals value2; otherwise it returns value1. // NULLIF function returns a null value if value1 equals value2; otherwise it returns value1.
func NULLIF(value1, value2 Expression) Expression { func NULLIF(value1, value2 Expression) Expression {
return newFunc("NULLIF", []Expression{value1, value2}, nil) return NewFunc("NULLIF", []Expression{value1, value2}, nil)
} }
// GREATEST selects the largest value from a list of expressions // GREATEST selects the largest value from a list of expressions
func GREATEST(value Expression, values ...Expression) Expression { func GREATEST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value} var allValues = []Expression{value}
allValues = append(allValues, values...) allValues = append(allValues, values...)
return newFunc("GREATEST", allValues, nil) return NewFunc("GREATEST", allValues, nil)
} }
// LEAST selects the smallest value from a list of expressions // LEAST selects the smallest value from a list of expressions
func LEAST(value Expression, values ...Expression) Expression { func LEAST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value} var allValues = []Expression{value}
allValues = append(allValues, values...) allValues = append(allValues, values...)
return newFunc("LEAST", allValues, nil) return NewFunc("LEAST", allValues, nil)
} }
//--------------------------------------------------------------------// //--------------------------------------------------------------------//
@ -590,7 +590,8 @@ type funcExpressionImpl struct {
noBrackets bool noBrackets bool
} }
func newFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { // NewFunc creates new function with name and expressions parameters
func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl {
funcExp := &funcExpressionImpl{ funcExp := &funcExpressionImpl{
name: name, name: name,
expressions: expressions, expressions: expressions,
@ -608,7 +609,7 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
// NewFloatWindowFunc creates new float function with name and expressions // NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression { func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := newFunc(name, expressions, nil) newFun := NewFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun) windowExpr := newWindowExpression(newFun)
newFun.ExpressionInterfaceImpl.Parent = windowExpr newFun.ExpressionInterfaceImpl.Parent = windowExpr
@ -645,7 +646,7 @@ type boolFunc struct {
func newBoolFunc(name string, expressions ...Expression) BoolExpression { func newBoolFunc(name string, expressions ...Expression) BoolExpression {
boolFunc := &boolFunc{} boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc)
boolFunc.boolInterfaceImpl.parent = boolFunc boolFunc.boolInterfaceImpl.parent = boolFunc
boolFunc.ExpressionInterfaceImpl.Parent = boolFunc boolFunc.ExpressionInterfaceImpl.Parent = boolFunc
@ -656,7 +657,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression {
func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression {
boolFunc := &boolFunc{} boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc)
intWindowFunc := newBoolWindowExpression(boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc)
boolFunc.boolInterfaceImpl.parent = intWindowFunc boolFunc.boolInterfaceImpl.parent = intWindowFunc
boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc
@ -673,7 +674,7 @@ type floatFunc struct {
func NewFloatFunc(name string, expressions ...Expression) FloatExpression { func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
floatFunc := &floatFunc{} floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.floatInterfaceImpl.parent = floatFunc floatFunc.floatInterfaceImpl.parent = floatFunc
return floatFunc return floatFunc
@ -683,7 +684,7 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression {
floatFunc := &floatFunc{} floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatWindowFunc := newFloatWindowExpression(floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc)
floatFunc.floatInterfaceImpl.parent = floatWindowFunc floatFunc.floatInterfaceImpl.parent = floatWindowFunc
floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc
@ -699,7 +700,7 @@ type integerFunc struct {
func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
floatFunc := &integerFunc{} floatFunc := &integerFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.integerInterfaceImpl.parent = floatFunc floatFunc.integerInterfaceImpl.parent = floatFunc
return floatFunc return floatFunc
@ -709,7 +710,7 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression {
integerFunc := &integerFunc{} integerFunc := &integerFunc{}
integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) integerFunc.funcExpressionImpl = *NewFunc(name, expressions, integerFunc)
intWindowFunc := newIntegerWindowExpression(integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc)
integerFunc.integerInterfaceImpl.parent = intWindowFunc integerFunc.integerInterfaceImpl.parent = intWindowFunc
integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc
@ -722,10 +723,11 @@ type stringFunc struct {
stringInterfaceImpl stringInterfaceImpl
} }
func newStringFunc(name string, expressions ...Expression) StringExpression { // NewStringFunc creates new string function with name and expression parameters
func NewStringFunc(name string, expressions ...Expression) StringExpression {
stringFunc := &stringFunc{} stringFunc := &stringFunc{}
stringFunc.funcExpressionImpl = *newFunc(name, expressions, stringFunc) stringFunc.funcExpressionImpl = *NewFunc(name, expressions, stringFunc)
stringFunc.stringInterfaceImpl.parent = stringFunc stringFunc.stringInterfaceImpl.parent = stringFunc
return stringFunc return stringFunc
@ -736,10 +738,11 @@ type dateFunc struct {
dateInterfaceImpl dateInterfaceImpl
} }
func newDateFunc(name string, expressions ...Expression) *dateFunc { // NewDateFunc creates new date function with name and expression parameters
func NewDateFunc(name string, expressions ...Expression) *dateFunc {
dateFunc := &dateFunc{} dateFunc := &dateFunc{}
dateFunc.funcExpressionImpl = *newFunc(name, expressions, dateFunc) dateFunc.funcExpressionImpl = *NewFunc(name, expressions, dateFunc)
dateFunc.dateInterfaceImpl.parent = dateFunc dateFunc.dateInterfaceImpl.parent = dateFunc
return dateFunc return dateFunc
@ -750,10 +753,11 @@ type timeFunc struct {
timeInterfaceImpl timeInterfaceImpl
} }
func newTimeFunc(name string, expressions ...Expression) *timeFunc { // NewTimeFunc creates new time function with name and expression parameters
func NewTimeFunc(name string, expressions ...Expression) *timeFunc {
timeFun := &timeFunc{} timeFun := &timeFunc{}
timeFun.funcExpressionImpl = *newFunc(name, expressions, timeFun) timeFun.funcExpressionImpl = *NewFunc(name, expressions, timeFun)
timeFun.timeInterfaceImpl.parent = timeFun timeFun.timeInterfaceImpl.parent = timeFun
return timeFun return timeFun
@ -767,7 +771,7 @@ type timezFunc struct {
func newTimezFunc(name string, expressions ...Expression) *timezFunc { func newTimezFunc(name string, expressions ...Expression) *timezFunc {
timezFun := &timezFunc{} timezFun := &timezFunc{}
timezFun.funcExpressionImpl = *newFunc(name, expressions, timezFun) timezFun.funcExpressionImpl = *NewFunc(name, expressions, timezFun)
timezFun.timezInterfaceImpl.parent = timezFun timezFun.timezInterfaceImpl.parent = timezFun
return timezFun return timezFun
@ -782,7 +786,7 @@ type timestampFunc struct {
func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc { func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc {
timestampFunc := &timestampFunc{} timestampFunc := &timestampFunc{}
timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) timestampFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampFunc)
timestampFunc.timestampInterfaceImpl.parent = timestampFunc timestampFunc.timestampInterfaceImpl.parent = timestampFunc
return timestampFunc return timestampFunc
@ -796,7 +800,7 @@ type timestampzFunc struct {
func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
timestampzFunc := &timestampzFunc{} timestampzFunc := &timestampzFunc{}
timestampzFunc.funcExpressionImpl = *newFunc(name, expressions, timestampzFunc) timestampzFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampzFunc)
timestampzFunc.timestampzInterfaceImpl.parent = timestampzFunc timestampzFunc.timestampzInterfaceImpl.parent = timestampzFunc
return timestampzFunc return timestampzFunc
@ -804,5 +808,5 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
// Func can be used to call an custom or as of yet unsupported function in the database. // Func can be used to call an custom or as of yet unsupported function in the database.
func Func(name string, expressions ...Expression) Expression { func Func(name string, expressions ...Expression) Expression {
return newFunc(name, expressions, nil) return NewFunc(name, expressions, nil)
} }

View file

@ -19,7 +19,7 @@ func (i *IsIntervalImpl) isInterval() {}
// NewInterval creates new interval from serializer // NewInterval creates new interval from serializer
func NewInterval(s Serializer) *IntervalImpl { func NewInterval(s Serializer) *IntervalImpl {
newInterval := &IntervalImpl{ newInterval := &IntervalImpl{
interval: s, Value: s,
} }
return newInterval return newInterval
@ -27,11 +27,11 @@ func NewInterval(s Serializer) *IntervalImpl {
// IntervalImpl is implementation of Interval type // IntervalImpl is implementation of Interval type
type IntervalImpl struct { type IntervalImpl struct {
interval Serializer Value Serializer
IsIntervalImpl IsIntervalImpl
} }
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL") out.WriteString("INTERVAL")
i.interval.serialize(statement, out, FallTrough(options)...) i.Value.serialize(statement, out, FallTrough(options)...)
} }

View file

@ -375,9 +375,14 @@ type wrap struct {
expressions []Expression expressions []Expression
} }
func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(") out.WriteString("(")
serializeExpressionList(statement, n.expressions, ", ", out)
if len(n.expressions) == 1 {
options = append(options, NoWrap, Ident)
}
serializeExpressionList(statementType, n.expressions, ", ", out, options...)
out.WriteString(")") out.WriteString(")")
} }

View file

@ -7,8 +7,10 @@ type SerializeOption int
const ( const (
NoWrap SerializeOption = iota NoWrap SerializeOption = iota
SkipNewLine SkipNewLine
Ident
fallTroughOptions // fall trough options fallTroughOptions // fall trough options
ShortName ShortName
) )

View file

@ -195,10 +195,19 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti
out.IncreaseIdent() out.IncreaseIdent()
} }
if contains(options, Ident) {
out.IncreaseIdent()
}
for _, clause := range s.Clauses { for _, clause := range s.Clauses {
clause.Serialize(statement, out, FallTrough(options)...) clause.Serialize(statement, out, FallTrough(options)...)
} }
if contains(options, Ident) {
out.DecreaseIdent()
out.NewLine()
}
if !contains(options, NoWrap) { if !contains(options, NoWrap) {
out.DecreaseIdent() out.DecreaseIdent()
out.NewLine() out.NewLine()

View file

@ -21,14 +21,19 @@ func SerializeClauseList(statement StatementType, clauses []Serializer, out *SQL
} }
} }
func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SQLBuilder) { func serializeExpressionList(
statement StatementType,
expressions []Expression,
separator string,
out *SQLBuilder,
options ...SerializeOption) {
for i, value := range expressions { for i, expression := range expressions {
if i > 0 { if i > 0 {
out.WriteString(separator) out.WriteString(separator)
} }
value.serialize(statement, out) expression.serialize(statement, out, options...)
} }
} }

View file

@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -20,6 +20,11 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
// UnixTimeComparer will compare time equality while ignoring time zone
var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool {
return t1.Unix() == t2.Unix()
})
// AssertExec assert statement execution for successful execution and number of rows affected // AssertExec assert statement execution for successful execution and number of rows affected
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
@ -66,7 +71,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath) filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644) err := ioutil.WriteFile(filePath, jsonText, 0644)
utils.PanicOnError(err) throw.OnError(err)
} }
// AssertJSONFile check if data json representation is the same as json at testRelativePath // AssertJSONFile check if data json representation is the same as json at testRelativePath
@ -113,7 +118,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
_, args := query.Sql() _, args := query.Sql()
if len(expectedArgs) > 0 { if len(expectedArgs) > 0 {
AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") AssertDeepEqual(t, args, expectedArgs)
} }
debugSql := query.DebugSql() debugSql := query.DebugSql()
@ -223,9 +228,9 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
} }
// AssertDeepEqual checks if actual and expected objects are deeply equal. // AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) {
if !assert.True(t, cmp.Equal(actual, expected), msg) { if !assert.True(t, cmp.Equal(actual, expected, option...)) {
printDiff(actual, expected) printDiff(actual, expected, option...)
t.FailNow() t.FailNow()
} }
} }
@ -237,7 +242,8 @@ func assertQueryString(t *testing.T, actual, expected string) {
} }
} }
func printDiff(actual, expected interface{}) { func printDiff(actual, expected interface{}, options ...cmp.Option) {
fmt.Println(cmp.Diff(actual, expected, options...))
fmt.Println("Actual: ") fmt.Println("Actual: ")
fmt.Println(actual) fmt.Println(actual)
fmt.Println("Expected: ") fmt.Println("Expected: ")

View file

@ -1,7 +1,7 @@
package testutils package testutils
import ( import (
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"strings" "strings"
"time" "time"
) )
@ -10,7 +10,7 @@ import (
func Date(t string) *time.Time { func Date(t string) *time.Time {
newTime, err := time.Parse("2006-01-02", t) newTime, err := time.Parse("2006-01-02", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }
@ -26,7 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000")
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }
@ -35,7 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
func TimeWithoutTimeZone(t string) *time.Time { func TimeWithoutTimeZone(t string) *time.Time {
newTime, err := time.Parse("15:04:05", t) newTime, err := time.Parse("15:04:05", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }
@ -44,7 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time {
func TimeWithTimeZone(t string) *time.Time { func TimeWithTimeZone(t string) *time.Time {
newTimez, err := time.Parse("15:04:05 -0700", t) newTimez, err := time.Parse("15:04:05 -0700", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTimez return &newTimez
} }
@ -60,7 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }

View file

@ -0,0 +1,9 @@
package min
// Int returns minimum of two int values
func Int(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -0,0 +1,8 @@
package throw
// OnError will panic if err is not nill
func OnError(err error) {
if err != nil {
panic(err)
}
}

View file

@ -10,7 +10,6 @@ import (
"reflect" "reflect"
"strings" "strings"
"time" "time"
"unicode"
) )
// ToGoIdentifier converts database to Go identifier. // ToGoIdentifier converts database to Go identifier.
@ -18,16 +17,6 @@ func ToGoIdentifier(databaseIdentifier string) string {
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
} }
// ToGoEnumValueIdentifier converts enum value name to Go identifier name.
func ToGoEnumValueIdentifier(enumName, enumValue string) string {
enumValueIdentifier := ToGoIdentifier(enumValue)
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {
return ToGoIdentifier(enumName) + enumValueIdentifier
}
return enumValueIdentifier
}
// ToGoFileName converts database identifier to Go file name. // ToGoFileName converts database identifier to Go file name.
func ToGoFileName(databaseIdentifier string) string { func ToGoFileName(databaseIdentifier string) string {
return strings.ToLower(replaceInvalidChars(databaseIdentifier)) return strings.ToLower(replaceInvalidChars(databaseIdentifier))
@ -35,7 +24,11 @@ func ToGoFileName(databaseIdentifier string) string {
// SaveGoFile saves go file at folder dir, with name fileName and contents text. // SaveGoFile saves go file at folder dir, with name fileName and contents text.
func SaveGoFile(dirPath, fileName string, text []byte) error { func SaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath := filepath.Join(dirPath, fileName) + ".go" newGoFilePath := filepath.Join(dirPath, fileName)
if !strings.HasSuffix(newGoFilePath, ".go") {
newGoFilePath += ".go"
}
file, err := os.Create(newGoFilePath) file, err := os.Create(newGoFilePath)
@ -160,13 +153,6 @@ func MustBeInitializedPtr(val interface{}, errorStr string) {
} }
} }
// PanicOnError panics if err is not nil
func PanicOnError(err error) {
if err != nil {
panic(err)
}
}
// ErrorCatch is used in defer to recover from panics and to set err // ErrorCatch is used in defer to recover from panics and to set err
func ErrorCatch(err *error) { func ErrorCatch(err *error) {
recovered := recover() recovered := recover()

View file

@ -25,11 +25,6 @@ func TestToGoIdentifier(t *testing.T) {
require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") require.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
} }
func TestToGoEnumValueIdentifier(t *testing.T) {
require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue")
require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100")
}
func TestErrorCatchErr(t *testing.T) { func TestErrorCatchErr(t *testing.T) {
var err error var err error

View file

@ -7,7 +7,6 @@ import (
) )
func TestInvalidInsert(t *testing.T) { func TestInvalidInsert(t *testing.T) {
assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement")
assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list")
} }

View file

@ -43,7 +43,7 @@ type SelectStatement interface {
DISTINCT() SelectStatement DISTINCT() SelectStatement
FROM(tables ...ReadableTable) SelectStatement FROM(tables ...ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
@ -118,7 +118,7 @@ func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement {
return s return s
} }
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement { func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.GroupBy.List = groupByClauses s.GroupBy.List = groupByClauses
return s return s
} }

View file

@ -20,5 +20,8 @@ type PrintableStatement = jet.PrintableStatement
// OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY. // OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY.
type OrderByClause = jet.OrderByClause type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging // SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc var SetLogger = jet.SetLoggerFunc

View file

@ -4,33 +4,10 @@ import (
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
type clauseReturning struct {
ProjectionList []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(r.ProjectionList) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.ProjectionList)
out.DecreaseIdent()
}
func (r clauseReturning) Projections() ProjectionList {
return r.ProjectionList
}
// ========================================== //
type onConflict interface { type onConflict interface {
ON_CONSTRAINT(name string) conflictTarget ON_CONSTRAINT(name string) conflictTarget
WHERE(indexPredicate BoolExpression) conflictTarget WHERE(indexPredicate BoolExpression) conflictTarget
DO_NOTHING() InsertStatement conflictTarget
DO_UPDATE(action conflictAction) InsertStatement
} }
type conflictTarget interface { type conflictTarget interface {

View file

@ -16,7 +16,7 @@ type deleteStatementImpl struct {
Delete jet.ClauseStatementBegin Delete jet.ClauseStatementBegin
Where jet.ClauseWhere Where jet.ClauseWhere
Returning clauseReturning Returning jet.ClauseReturning
} }
func newDeleteStatement(table WritableTable) DeleteStatement { func newDeleteStatement(table WritableTable) DeleteStatement {

View file

@ -46,33 +46,33 @@ func TestExists(t *testing.T) {
func TestIN(t *testing.T) { func TestIN(t *testing.T) {
assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN (( `($1 IN (
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
FROM db.table1 FROM db.table1
)))`, float64(1.11)) ))`, float64(1.11))
assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN (( `(ROW($1, table1.col1) IN (
SELECT table2.col3 AS "table2.col3", SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1" table3.col1 AS "table3.col1"
FROM db.table2 FROM db.table2
)))`, int64(12)) ))`, int64(12))
} }
func TestNOT_IN(t *testing.T) { func TestNOT_IN(t *testing.T) {
assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
`($1 NOT IN (( `($1 NOT IN (
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
FROM db.table1 FROM db.table1
)))`, float64(1.11)) ))`, float64(1.11))
assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN (( `(ROW($1, table1.col1) NOT IN (
SELECT table2.col3 AS "table2.col3", SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1" table3.col1 AS "table3.col1"
FROM db.table2 FROM db.table2
)))`, int64(12)) ))`, int64(12))
} }
func TestReservedWordEscaped(t *testing.T) { func TestReservedWordEscaped(t *testing.T) {

View file

@ -22,7 +22,11 @@ type InsertStatement interface {
func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{} newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning) &newInsert.Insert,
&newInsert.ValuesQuery,
&newInsert.OnConflict,
&newInsert.Returning,
)
newInsert.Insert.Table = table newInsert.Insert.Table = table
newInsert.Insert.Columns = columns newInsert.Insert.Columns = columns
@ -35,7 +39,7 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery ValuesQuery jet.ClauseValuesQuery
Returning clauseReturning Returning jet.ClauseReturning
OnConflict onConflictClause OnConflict onConflictClause
} }

View file

@ -8,7 +8,6 @@ import (
) )
func TestInvalidInsert(t *testing.T) { func TestInvalidInsert(t *testing.T) {
assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement")
assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list")
} }
@ -155,7 +154,7 @@ func TestInsert_ON_CONFLICT(t *testing.T) {
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool.SET(Bool(true)), SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)), table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))), ).WHERE(table1Col1.GT(Int(2))),
). ).
RETURNING(table1Col1, table1ColBool) RETURNING(table1Col1, table1ColBool)

View file

@ -116,7 +116,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression {
panic("jet: invalid number of quantity and unit fields") panic("jet: invalid number of quantity and unit fields")
} }
fields := []string{} var fields []string
for i := 0; i < len(quantityAndUnit); i += 2 { for i := 0; i < len(quantityAndUnit); i += 2 {
quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64) quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64)

View file

@ -1,8 +1,9 @@
package postgres package postgres
import ( import (
"github.com/go-jet/jet/v2/internal/jet"
"math" "math"
"github.com/go-jet/jet/v2/internal/jet"
) )
// RowLock is interface for SELECT statement row lock types // RowLock is interface for SELECT statement row lock types
@ -46,7 +47,7 @@ type SelectStatement interface {
DISTINCT() SelectStatement DISTINCT() SelectStatement
FROM(tables ...ReadableTable) SelectStatement FROM(tables ...ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
@ -121,7 +122,7 @@ func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement {
return s return s
} }
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement { func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.GroupBy.List = groupByClauses s.GroupBy.List = groupByClauses
return s return s
} }

View file

@ -20,5 +20,8 @@ type PrintableStatement = jet.PrintableStatement
// OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY. // OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY.
type OrderByClause = jet.OrderByClause type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging // SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc var SetLogger = jet.SetLoggerFunc

View file

@ -22,7 +22,7 @@ type updateStatementImpl struct {
Set clauseSet Set clauseSet
SetNew jet.SetClauseNew SetNew jet.SetClauseNew
Where jet.ClauseWhere Where jet.ClauseWhere
Returning clauseReturning Returning jet.ClauseReturning
} }
func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement {

View file

@ -1,263 +1,175 @@
package internal package internal
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/min"
"reflect"
"strconv" "strconv"
"time" "time"
) )
//===============================================================// // NullBool struct
type NullBool struct {
// NullByteArray struct sql.NullBool
type NullByteArray struct {
ByteArray []byte
Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error { func (nb *NullBool) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case nil: case bool:
nb.Valid = false nb.Bool, nb.Valid = v, true
return nil case int8, int16, int32, int64, int:
case []byte: intVal := reflect.ValueOf(v).Int()
nb.ByteArray = append(v[:0:0], v...)
if intVal != 0 && intVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = intVal == 1
nb.Valid = true
case uint8, uint16, uint32, uint64, uint:
uintVal := reflect.ValueOf(v).Uint()
if uintVal != 0 && uintVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = uintVal == 1
nb.Valid = true nb.Valid = true
return nil
default: default:
return fmt.Errorf("can't scan []byte from %v", value) return nb.NullBool.Scan(value)
}
} }
// Value implements the driver Valuer interface. return nil
func (nb NullByteArray) Value() (driver.Value, error) {
if !nb.Valid {
return nil, nil
} }
return nb.ByteArray, nil
}
//===============================================================//
// NullTime struct // NullTime struct
type NullTime struct { type NullTime struct {
Time time.Time sql.NullTime
Valid bool // Valid is true if Time is not NULL
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) (err error) { func (nt *NullTime) Scan(value interface{}) error {
switch v := value.(type) { err := nt.NullTime.Scan(value)
case nil:
nt.Valid = false if err == nil {
return return nil
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, nt.Valid = parseTime(string(v))
return
case string:
nt.Time, nt.Valid = parseTime(v)
return
default:
return fmt.Errorf("can't scan time.Time from %v", value)
}
} }
// Value implements the driver Valuer interface. // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
func (nt NullTime) Value() (driver.Value, error) { // At this point we try to parse those values using some of the predefined formats
nt.Time, nt.Valid = tryParseAsTime(value)
if !nt.Valid { if !nt.Valid {
return nil, nil return fmt.Errorf("can't scan time.Time from %q", value)
}
return nt.Time, nil
} }
const formatTime = "2006-01-02 15:04:05.999999" return nil
func parseTime(timeStr string) (t time.Time, valid bool) {
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
} }
t, err := time.Parse(format, timeStr) var formats = []string{
return t, err == nil "2006-01-02 15:04:05-07:00", // sqlite
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
} }
//===============================================================// func tryParseAsTime(value interface{}) (time.Time, bool) {
// NullInt8 struct var timeStr string
type NullInt8 struct {
Int8 int8
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt8) Scan(value interface{}) (err error) {
switch v := value.(type) { switch v := value.(type) {
case nil: case string:
n.Valid = false timeStr = v
return
case int64:
n.Int8, n.Valid = int8(v), true
return
case int8:
n.Int8, n.Valid = v, true
return
case []byte: case []byte:
intV, err := strconv.ParseInt(string(v), 10, 8) timeStr = string(v)
if err == nil { case int64:
n.Int8, n.Valid = int8(intV), true return time.Unix(v, 0), true // sqlite
}
return err
default: default:
return fmt.Errorf("can't scan int8 from %v", value) return time.Time{}, false
}
} }
// Value implements the driver Valuer interface. for _, format := range formats {
func (n NullInt8) Value() (driver.Value, error) { formatLen := min.Int(len(format), len(timeStr))
if !n.Valid { t, err := time.Parse(format[:formatLen], timeStr)
return nil, nil
} if err != nil {
return n.Int8, nil continue
} }
//===============================================================// return t, true
}
// NullInt16 struct return time.Time{}, false
type NullInt16 struct { }
Int16 int16
// NullUInt64 struct
type NullUInt64 struct {
UInt64 uint64
Valid bool Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error { func (n *NullUInt64) Scan(value interface{}) error {
var stringValue string
switch v := value.(type) { switch v := value.(type) {
case nil: case nil:
n.Valid = false n.Valid = false
return nil return nil
case int64: case int64:
n.Int16, n.Valid = int16(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int16: case uint64:
n.Int16, n.Valid = v, true n.UInt64, n.Valid = v, true
return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8:
n.Int16, n.Valid = int16(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
}
return nil
default:
return fmt.Errorf("can't scan int16 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int16, nil
}
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int32, n.Valid = int32(v), true
return nil return nil
case int32: case int32:
n.Int32, n.Valid = v, true n.UInt64, n.Valid = uint64(v), true
return nil
case uint32:
n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int16: case int16:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case uint16: case uint16:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int8: case int8:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case uint8: case uint8:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil
case int:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint:
n.UInt64, n.Valid = uint64(v), true
return nil return nil
case []byte: case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32) stringValue = string(v)
if err == nil { case string:
n.Int32, n.Valid = int32(intV), true stringValue = v
}
return nil
default: default:
return fmt.Errorf("can't scan int32 from %v", value) return fmt.Errorf("can't scan uint64 from %v", value)
} }
uintV, err := strconv.ParseUint(stringValue, 10, 64)
if err != nil {
return err
}
n.UInt64 = uintV
n.Valid = true
return nil
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) { func (n NullUInt64) Value() (driver.Value, error) {
if !n.Valid { if !n.Valid {
return nil, nil return nil, nil
} }
return n.Int32, nil return n.UInt64, nil
}
//===============================================================//
// NullFloat32 struct
type NullFloat32 struct {
Float32 float32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullFloat32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case float64:
n.Float32, n.Valid = float32(v), true
return nil
case float32:
n.Float32, n.Valid = v, true
return nil
default:
return fmt.Errorf("can't scan float32 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullFloat32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Float32, nil
} }

View file

@ -7,141 +7,85 @@ import (
"time" "time"
) )
func TestNullByteArray(t *testing.T) { func TestNullBool(t *testing.T) {
var array NullByteArray var nullBool NullBool
require.NoError(t, array.Scan(nil)) require.NoError(t, nullBool.Scan(nil))
require.Equal(t, array.Valid, false) require.Equal(t, nullBool.Valid, false)
require.NoError(t, array.Scan([]byte("bytea"))) require.NoError(t, nullBool.Scan(int64(1)))
require.Equal(t, array.Valid, true) require.Equal(t, nullBool.Valid, true)
require.Equal(t, string(array.ByteArray), string([]byte("bytea"))) value, _ := nullBool.Value()
require.Equal(t, value, true)
require.Error(t, array.Scan(12), "can't scan []byte from 12") require.NoError(t, nullBool.Scan(uint32(0)))
require.Equal(t, nullBool.Valid, true)
value, _ = nullBool.Value()
require.Equal(t, value, false)
require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool")
} }
func TestNullTime(t *testing.T) { func TestNullTime(t *testing.T) {
var array NullTime var nullTime NullTime
require.NoError(t, array.Scan(nil)) require.NoError(t, nullTime.Scan(nil))
require.Equal(t, array.Valid, false) require.Equal(t, nullTime.Valid, false)
time := time.Now() time := time.Now()
require.NoError(t, array.Scan(time)) require.NoError(t, nullTime.Scan(time))
require.Equal(t, array.Valid, true) require.Equal(t, nullTime.Valid, true)
value, _ := array.Value() value, _ := nullTime.Value()
require.Equal(t, value, time) require.Equal(t, value, time)
require.NoError(t, array.Scan([]byte("13:10:11"))) require.NoError(t, nullTime.Scan([]byte("13:10:11")))
require.Equal(t, array.Valid, true) require.Equal(t, nullTime.Valid, true)
value, _ = array.Value() value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.NoError(t, array.Scan("13:10:11")) require.NoError(t, nullTime.Scan("13:10:11"))
require.Equal(t, array.Valid, true) require.Equal(t, nullTime.Valid, true)
value, _ = array.Value() value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.Error(t, array.Scan(12), "can't scan time.Time from 12") require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12")
} }
func TestNullInt8(t *testing.T) { func TestNullUInt64(t *testing.T) {
var array NullInt8 var nullUInt64 NullUInt64
require.NoError(t, array.Scan(nil)) require.NoError(t, nullUInt64.Scan(nil))
require.Equal(t, array.Valid, false) require.Equal(t, nullUInt64.Valid, false)
require.NoError(t, array.Scan(int64(11))) require.NoError(t, nullUInt64.Scan(int64(11)))
require.Equal(t, array.Valid, true) require.Equal(t, nullUInt64.Valid, true)
value, _ := array.Value() value, _ := nullUInt64.Value()
require.Equal(t, value, int8(11)) require.Equal(t, value, uint64(11))
require.Error(t, array.Scan("text"), "can't scan int8 from text") require.NoError(t, nullUInt64.Scan(int32(32)))
} require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
func TestNullInt16(t *testing.T) { require.Equal(t, value, uint64(32))
var array NullInt16
require.NoError(t, nullUInt64.Scan(int16(20)))
require.NoError(t, array.Scan(nil)) require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, array.Valid, false) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(20))
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true) require.NoError(t, nullUInt64.Scan(uint16(16)))
value, _ := array.Value() require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, value, int16(11)) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(16))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true) require.NoError(t, nullUInt64.Scan(int8(30)))
value, _ = array.Value() require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, value, int16(20)) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true) require.NoError(t, nullUInt64.Scan(uint8(30)))
value, _ = array.Value() require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, value, int16(30)) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true) require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
value, _ = array.Value()
require.Equal(t, value, int16(30))
require.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int32(11))
require.NoError(t, array.Scan(int32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(32))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(20))
require.NoError(t, array.Scan(uint16(16)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(16))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(float64(64)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, float32(64))
require.NoError(t, array.Scan(float32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, float32(32))
require.Error(t, array.Scan(12), "can't scan float32 from 12")
} }

View file

@ -27,7 +27,10 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
_, err := queryToSlice(ctx, db, query, args, destPtr) _, err := queryToSlice(ctx, db, query, args, destPtr)
return err if err != nil {
return fmt.Errorf("jet: %w", err)
}
return nil
} else if destinationPtrType.Elem().Kind() == reflect.Struct { } else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem() tempSliceValue := tempSlicePtrValue.Elem()
@ -35,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
if err != nil { if err != nil {
return err return fmt.Errorf("jet: %w", err)
} }
if rowsProcessed == 0 { if rowsProcessed == 0 {
@ -214,7 +217,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
} }
rowElemPtr := scanContext.rowElemValuePtr(index) rowElemPtr := scanContext.rowElemValuePtr(index)
if !rowElemPtr.IsNil() { if rowElemPtr.IsValid() && !rowElemPtr.IsNil() {
updated = true updated = true
err = appendElemToSlice(slicePtrValue, rowElemPtr) err = appendElemToSlice(slicePtrValue, rowElemPtr)
if err != nil { if err != nil {
@ -275,10 +278,16 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
err = scanner.Scan(cellValue) err = scanner.Scan(cellValue)
if err != nil { if err != nil {
panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
} }
} else { } else {
setReflectValue(reflect.ValueOf(cellValue), fieldValue) err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if err != nil {
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
}
} }
} }
} }

View file

@ -2,9 +2,7 @@ package qrm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils"
"reflect" "reflect"
"strings" "strings"
) )
@ -45,7 +43,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
} }
return &scanContext{ return &scanContext{
row: createScanValue(columnTypes), row: createScanSlice(len(columnTypes)),
uniqueDestObjectsMap: make(map[string]int), uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo), groupKeyInfoCache: make(map[string]groupKeyInfo),
@ -55,6 +53,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
}, nil }, nil
} }
func createScanSlice(columnCount int) []interface{} {
scanSlice := make([]interface{}, columnCount)
scanPtrSlice := make([]interface{}, columnCount)
for i := range scanPtrSlice {
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value
}
return scanPtrSlice
}
type typeInfo struct { type typeInfo struct {
fieldMappings []fieldMapping fieldMappings []fieldMapping
} }
@ -209,22 +218,23 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
} }
func (s *scanContext) rowElem(index int) interface{} { func (s *scanContext) rowElem(index int) interface{} {
cellValue := reflect.ValueOf(s.row[index])
valuer, ok := s.row[index].(driver.Valuer) if cellValue.IsValid() && !cellValue.IsNil() {
return cellValue.Elem().Interface()
}
utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer") return nil
value, err := valuer.Value()
utils.PanicOnError(err)
return value
} }
func (s *scanContext) rowElemValuePtr(index int) reflect.Value { func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
rowElem := s.rowElem(index) rowElem := s.rowElem(index)
rowElemValue := reflect.ValueOf(rowElem) rowElemValue := reflect.ValueOf(rowElem)
if !rowElemValue.IsValid() {
return reflect.Value{}
}
if rowElemValue.Kind() == reflect.Ptr { if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue return rowElemValue
} }

View file

@ -7,7 +7,6 @@ import (
"github.com/go-jet/jet/v2/qrm/internal" "github.com/go-jet/jet/v2/qrm/internal"
"github.com/google/uuid" "github.com/google/uuid"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -56,21 +55,30 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
sliceValue := slicePtrValue.Elem() sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem() sliceElemType := sliceValue.Type().Elem()
newElemValue := objPtrValue var newSliceElemValue reflect.Value
if sliceElemType.Kind() != reflect.Ptr { if objPtrValue.Type().AssignableTo(sliceElemType) {
newElemValue = objPtrValue.Elem() newSliceElemValue = objPtrValue
} else if objPtrValue.Elem().Type().AssignableTo(sliceElemType) {
newSliceElemValue = objPtrValue.Elem()
} else {
newSliceElemValue = reflect.New(sliceElemType).Elem()
var err error
if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
} }
if newElemValue.Type().ConvertibleTo(sliceElemType) { if err != nil {
newElemValue = newElemValue.Convert(sliceElemType) return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err)
}
} }
if !newElemValue.Type().AssignableTo(sliceElemType) { sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue))
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
return nil return nil
} }
@ -121,7 +129,6 @@ func toCommonIdentifier(name string) string {
} }
func initializeValueIfNilPtr(value reflect.Value) { func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() { if !value.IsValid() || !value.CanSet() {
return return
} }
@ -173,172 +180,160 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType return objType == timeType || objType == uuidType || objType == byteArrayType
} }
func isIntegerType(value reflect.Type) bool { func isIntegerType(objType reflect.Type) bool {
switch value { objType = indirectType(objType)
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type: switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true return true
} }
return false return false
} }
func isNumber(valueType reflect.Type) bool { func isFloatType(value reflect.Type) bool {
return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type switch value.Kind() {
case reflect.Float32, reflect.Float64:
return true
} }
func tryAssign(source, destination reflect.Value) bool {
switch {
case source.Type().ConvertibleTo(destination.Type()):
source = source.Convert(destination.Type())
case isIntegerType(source.Type()) && destination.Type() == boolType:
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
case source.Type() == stringType && isNumber(destination.Type()):
// if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number
// and then parsed number is converted into destination type
f, err := strconv.ParseFloat(source.String(), 64)
if err != nil {
return false return false
} }
source = reflect.ValueOf(f)
if source.Type().ConvertibleTo(destination.Type()) { func tryAssign(source, destination reflect.Value) error {
if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
!(isIntegerType(source.Type()) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type()) source = source.Convert(destination.Type())
} }
}
if source.Type().AssignableTo(destination.Type()) { if source.Type().AssignableTo(destination.Type()) {
switch b := source.Interface().(type) {
case []byte:
destination.SetBytes(cloneBytes(b))
default:
destination.Set(source) destination.Set(source)
return true }
return nil
} }
return false sourceInterface := source.Interface()
switch destination.Interface().(type) {
case bool:
var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface)
if err != nil {
return err
} }
func setReflectValue(source, destination reflect.Value) { destination.SetBool(nullBool.Bool)
if tryAssign(source, destination) { case float32, float64:
return var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface)
if err != nil {
return err
} }
if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64)
}
case int, int8, int16, int32, int64:
var integer sql.NullInt64
err := integer.Scan(sourceInterface)
if err != nil {
return err
}
if integer.Valid {
destination.SetInt(integer.Int64)
}
case uint, uint8, uint16, uint32, uint64:
var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface)
if err != nil {
return err
}
if uInt.Valid {
destination.SetUint(uInt.UInt64)
}
case string:
var str sql.NullString
err := str.Scan(sourceInterface)
if err != nil {
return err
}
if str.Valid {
destination.SetString(str.String)
}
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
}
return nil
}
func setReflectValue(source, destination reflect.Value) error {
if destination.Kind() == reflect.Ptr { if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
if !source.IsNil() {
if destination.IsNil() { if destination.IsNil() {
initializeValueIfNilPtr(destination) initializeValueIfNilPtr(destination)
} }
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
}
} else {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
if source.IsNil() { if source.IsNil() {
return return nil // source is nil, destination should keep its zero value
} }
source = source.Elem() source = source.Elem()
} }
if tryAssign(source, destination) { if err := tryAssign(source, destination.Elem()); err != nil {
return return err
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination); err != nil {
return err
} }
} }
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) return nil
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes {
columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType)
values[i] = columnValue.Interface()
}
return values
}
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var float32Type = reflect.TypeOf(float32(1))
var float64Type = reflect.TypeOf(float64(1))
var stringType = reflect.TypeOf("")
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8", "BIGINT":
return nullInt64Type
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloat32Type
case "FLOAT8", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
}
} }
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
@ -385,3 +380,12 @@ func fieldToString(field *reflect.StructField) string {
return " at '" + field.Name + " " + field.Type.String() + "'" return " at '" + field.Name + " " + field.Type.String() + "'"
} }
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
c := make([]byte, len(b))
copy(c, b)
return c
}

View file

@ -58,25 +58,24 @@ func TestTryAssign(t *testing.T) {
testValue := reflect.ValueOf(&destination).Elem() testValue := reflect.ValueOf(&destination).Elem()
// convertible // convertible
require.True(t, tryAssign(reflect.ValueOf(convertible), testValue.FieldByName("Convertible"))) require.NoError(t, tryAssign(reflect.ValueOf(convertible), testValue.FieldByName("Convertible")))
require.Equal(t, int64(16), destination.Convertible) require.Equal(t, int64(16), destination.Convertible)
// 1/0 to bool // 1/0 to bool
require.True(t, tryAssign(reflect.ValueOf(intBool1), testValue.FieldByName("IntBool1"))) require.NoError(t, tryAssign(reflect.ValueOf(intBool1), testValue.FieldByName("IntBool1")))
require.Equal(t, true, destination.IntBool1) require.Equal(t, true, destination.IntBool1)
require.True(t, tryAssign(reflect.ValueOf(intBool0), testValue.FieldByName("IntBool0"))) require.NoError(t, tryAssign(reflect.ValueOf(intBool0), testValue.FieldByName("IntBool0")))
require.Equal(t, false, destination.IntBool0) require.Equal(t, false, destination.IntBool0)
require.False(t, tryAssign(reflect.ValueOf(intBool2), testValue.FieldByName("IntBool2"))) require.EqualError(t, tryAssign(reflect.ValueOf(intBool2), testValue.FieldByName("IntBool2")), "can't assign int32(2) to bool")
require.Equal(t, false, destination.IntBool2)
// string to float // string to float
require.True(t, tryAssign(reflect.ValueOf(floatStr), testValue.FieldByName("FloatStr"))) require.NoError(t, tryAssign(reflect.ValueOf(floatStr), testValue.FieldByName("FloatStr")))
require.Equal(t, 1.11, destination.FloatStr) require.Equal(t, 1.11, destination.FloatStr)
require.False(t, tryAssign(reflect.ValueOf(floatErr), testValue.FieldByName("FloatErr"))) require.EqualError(t, tryAssign(reflect.ValueOf(floatErr), testValue.FieldByName("FloatErr")), "converting driver.Value type string (\"1.abcd2\") to a float64: invalid syntax")
require.Equal(t, 0.00, destination.FloatErr) require.Equal(t, 0.00, destination.FloatErr)
// string to string // string to string
require.True(t, tryAssign(reflect.ValueOf(str), testValue.FieldByName("Str"))) require.NoError(t, tryAssign(reflect.ValueOf(str), testValue.FieldByName("Str")))
require.Equal(t, str, destination.Str) require.Equal(t, str, destination.Str)
} }

55
sqlite/cast.go Normal file
View file

@ -0,0 +1,55 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
type cast interface {
AS(castType string) Expression
AS_TEXT() StringExpression
AS_NUMERIC() FloatExpression
AS_INTEGER() IntegerExpression
AS_REAL() FloatExpression
AS_BLOB() StringExpression
}
type castImpl struct {
jet.Cast
}
// CAST function converts a expr (of any type) into latter specified datatype.
func CAST(expr Expression) cast {
castImpl := &castImpl{}
castImpl.Cast = jet.NewCastImpl(expr)
return castImpl
}
// AS casts expressions to castType
func (c *castImpl) AS(castType string) Expression {
return c.Cast.AS(castType)
}
// AS_TEXT cast expression to TEXT type
func (c *castImpl) AS_TEXT() StringExpression {
return StringExp(c.AS("TEXT"))
}
// AS_NUMERIC cast expression to NUMERIC type
func (c *castImpl) AS_NUMERIC() FloatExpression {
return FloatExp(c.AS("NUMERIC"))
}
// AS_INTEGER cast expression to INTEGER type
func (c *castImpl) AS_INTEGER() IntegerExpression {
return IntExp(c.AS("INTEGER"))
}
// AS_REAL cast expression to REAL type
func (c *castImpl) AS_REAL() FloatExpression {
return FloatExp(c.AS("REAL"))
}
// AS_BLOB cast expression to BLOB type
func (c *castImpl) AS_BLOB() StringExpression {
return StringExp(c.AS("BLOB"))
}

14
sqlite/cast_test.go Normal file
View file

@ -0,0 +1,14 @@
package sqlite
import (
"testing"
)
func TestCAST(t *testing.T) {
assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`)
assertSerialize(t, CAST(Int(22)).AS_TEXT(), `CAST(? AS TEXT)`)
assertSerialize(t, CAST(Int(22)).AS_NUMERIC(), `CAST(? AS NUMERIC)`)
assertSerialize(t, CAST(String("22")).AS_INTEGER(), `CAST(? AS INTEGER)`)
assertSerialize(t, CAST(String("22.2")).AS_REAL(), `CAST(? AS REAL)`)
assertSerialize(t, CAST(String("blob")).AS_BLOB(), `CAST(? AS BLOB)`)
}

58
sqlite/columns.go Normal file
View file

@ -0,0 +1,58 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Column is common column interface for all types of columns.
type Column = jet.ColumnExpression
// ColumnList function returns list of columns that be used as projection or column list for UPDATE and INSERT statement.
type ColumnList = jet.ColumnList
// ColumnBool is interface for SQL boolean columns.
type ColumnBool = jet.ColumnBool
// BoolColumn creates named bool column.
var BoolColumn = jet.BoolColumn
// ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types.
type ColumnString = jet.ColumnString
// StringColumn creates named string column.
var StringColumn = jet.StringColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger
// IntegerColumn creates named integer column.
var IntegerColumn = jet.IntegerColumn
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat = jet.ColumnFloat
// FloatColumn creates named float column.
var FloatColumn = jet.FloatColumn
// ColumnTime is interface for SQL time column.
type ColumnTime = jet.ColumnTime
// TimeColumn creates named time column
var TimeColumn = jet.TimeColumn
// ColumnDate is interface of SQL date columns.
type ColumnDate = jet.ColumnDate
// DateColumn creates named date column.
var DateColumn = jet.DateColumn
// ColumnDateTime is interface of SQL timestamp columns.
type ColumnDateTime = jet.ColumnTimestamp
// DateTimeColumn creates named timestamp column
var DateTimeColumn = jet.TimestampColumn
//ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp = jet.ColumnTimestamp
// TimestampColumn creates named timestamp column
var TimestampColumn = jet.TimestampColumn

View file

@ -0,0 +1,61 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// DeleteStatement is interface for MySQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement
LIMIT(limit int64) DeleteStatement
RETURNING(projections ...jet.Projection) DeleteStatement
}
type deleteStatementImpl struct {
jet.SerializerStatement
Delete jet.ClauseStatementBegin
Where jet.ClauseWhere
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
Returning jet.ClauseReturning
}
func newDeleteStatement(table Table) DeleteStatement {
newDelete := &deleteStatementImpl{}
newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete,
&newDelete.Delete,
&newDelete.Where,
&newDelete.OrderBy,
&newDelete.Limit,
&newDelete.Returning,
)
newDelete.Delete.Name = "DELETE FROM"
newDelete.Delete.Tables = append(newDelete.Delete.Tables, table)
newDelete.Where.Mandatory = true
newDelete.Limit.Count = -1
return newDelete
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.Where.Condition = expression
return d
}
func (d *deleteStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement {
d.OrderBy.List = orderByClauses
return d
}
func (d *deleteStatementImpl) LIMIT(limit int64) DeleteStatement {
d.Limit.Count = limit
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement {
d.Returning.ProjectionList = projections
return d
}

View file

@ -0,0 +1,26 @@
package sqlite
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementSqlErr(t, table1.DELETE(), `jet: WHERE clause not set`)
assertStatementSqlErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = ?;
`, int64(1))
}
func TestDeleteWithWhereOrderByLimit(t *testing.T) {
assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).ORDER_BY(table1Col1).LIMIT(1), `
DELETE FROM db.table1
WHERE table1.col1 = ?
ORDER BY table1.col1
LIMIT ?;
`, int64(1), int64(1))
}

225
sqlite/dialect.go Normal file
View file

@ -0,0 +1,225 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
// Dialect is implementation of SQL Builder for SQLite databases.
var Dialect = newDialect()
func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides["IS DISTINCT FROM"] = sqlite_IS_DISTINCT_FROM
operatorSerializeOverrides["IS NOT DISTINCT FROM"] = sqlite_IS_NOT_DISTINCT_FROM
operatorSerializeOverrides["#"] = sqliteBitXOR
mySQLDialectParams := jet.DialectParams{
Name: "SQLite",
PackageName: "sqlite",
OperatorSerializeOverrides: operatorSerializeOverrides,
AliasQuoteChar: '"',
IdentifierQuoteChar: '`',
ArgumentPlaceholder: func(int) string {
return "?"
},
ReservedWords: reservedWords2,
}
return jet.NewDialect(mySQLDialectParams)
}
func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator XOR")
}
// (~(a&b))&(a|b)
a := expressions[0]
b := expressions[1]
out.WriteString("(~(")
jet.Serialize(a, statement, out, options...)
out.WriteByte('&')
jet.Serialize(b, statement, out, options...)
out.WriteString("))&(")
jet.Serialize(a, statement, out, options...)
out.WriteByte('|')
jet.Serialize(b, statement, out, options...)
out.WriteByte(')')
}
}
func sqlite_IS_NOT_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out)
out.WriteString("IS")
jet.Serialize(expressions[1], statement, out)
}
}
func sqlite_IS_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out)
out.WriteString("IS NOT")
jet.Serialize(expressions[1], statement, out)
}
}
var reservedWords2 = []string{
"ABORT",
"ACTION",
"ADD",
"AFTER",
"ALL",
"ALTER",
"ALWAYS",
"ANALYZE",
"AND",
"AS",
"ASC",
"ATTACH",
"AUTOINCREMENT",
"BEFORE",
"BEGIN",
"BETWEEN",
"BY",
"CASCADE",
"CASE",
"CAST",
"CHECK",
"COLLATE",
"COLUMN",
"COMMIT",
"CONFLICT",
"CONSTRAINT",
"CREATE",
"CROSS",
"CURRENT",
"CURRENT_DATE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"DATABASE",
"DEFAULT",
"DEFERRABLE",
"DEFERRED",
"DELETE",
"DESC",
"DETACH",
"DISTINCT",
"DO",
"DROP",
"EACH",
"ELSE",
"END",
"ESCAPE",
"EXCEPT",
"EXCLUDE",
"EXCLUSIVE",
"EXISTS",
"EXPLAIN",
"FAIL",
"FILTER",
"FIRST",
"FOLLOWING",
"FOR",
"FOREIGN",
"FROM",
"FULL",
"GENERATED",
"GLOB",
"GROUP",
"GROUPS",
"HAVING",
"IF",
"IGNORE",
"IMMEDIATE",
"IN",
"INDEX",
"INDEXED",
"INITIALLY",
"INNER",
"INSERT",
"INSTEAD",
"INTERSECT",
"INTO",
"IS",
"ISNULL",
"JOIN",
"KEY",
"LAST",
"LEFT",
"LIKE",
"LIMIT",
"MATCH",
"MATERIALIZED",
"NATURAL",
"NO",
"NOT",
"NOTHING",
"NOTNULL",
"NULL",
"NULLS",
"OF",
"OFFSET",
"ON",
"OR",
"ORDER",
"OTHERS",
"OUTER",
"OVER",
"PARTITION",
"PLAN",
"PRAGMA",
"PRECEDING",
"PRIMARY",
"QUERY",
"RAISE",
"RANGE",
"RECURSIVE",
"REFERENCES",
"REGEXP",
"REINDEX",
"RELEASE",
"RENAME",
"REPLACE",
"RESTRICT",
"RETURNING",
"RIGHT",
"ROLLBACK",
"ROW",
"ROWS",
"SAVEPOINT",
"SELECT",
"SET",
"TABLE",
"TEMP",
"TEMPORARY",
"THEN",
"TIES",
"TO",
"TRANSACTION",
"TRIGGER",
"UNBOUNDED",
"UNION",
"UNIQUE",
"UPDATE",
"USING",
"VACUUM",
"VALUES",
"VIEW",
"VIRTUAL",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"WITHOUT",
}

59
sqlite/dialect_test.go Normal file
View file

@ -0,0 +1,59 @@
package sqlite
import (
"testing"
)
func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) {
assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS NOT table2.col_bool)")
assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS NOT ?)", false)
}
func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS table2.col_bool)")
assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS ?)", false)
}
func TestBoolLiteral(t *testing.T) {
assertSerialize(t, Bool(true), "?", true)
assertSerialize(t, Bool(false), "?", false)
}
func TestIntegerExpressionDIV(t *testing.T) {
assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int / table2.col_int)")
assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int / ?)", int64(11))
}
func TestIntExpressionPOW(t *testing.T) {
assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)")
assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11))
}
func TestIntExpressionBIT_XOR(t *testing.T) {
assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "((~(table1.col_int & table2.col_int))&(table1.col_int | table2.col_int))")
assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "((~(table1.col_int & ?))&(table1.col_int | ?))", int64(11), int64(11))
}
func TestExists(t *testing.T) {
assertSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`(EXISTS (
SELECT ?
FROM db.table2
WHERE table1.col1 = table2.col3
))`, int64(1))
}
func TestString_REGEXP_LIKE_operator(t *testing.T) {
assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)")
assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN")
}
func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) {
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)")
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN")
}

97
sqlite/expressions.go Normal file
View file

@ -0,0 +1,97 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time or Timestamp expressions.
type Expression = jet.Expression
// BoolExpression interface
type BoolExpression = jet.BoolExpression
// StringExpression interface
type StringExpression = jet.StringExpression
// NumericExpression is shared interface for integer or real expression
type NumericExpression = jet.NumericExpression
// IntegerExpression interface
type IntegerExpression = jet.IntegerExpression
// FloatExpression interface
type FloatExpression = jet.FloatExpression
// TimeExpression interface
type TimeExpression = jet.TimeExpression
// DateExpression interface
type DateExpression = jet.DateExpression
// DateTimeExpression interface
type DateTimeExpression = jet.TimestampExpression
// TimestampExpression interface
type TimestampExpression = jet.TimestampExpression
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
var BoolExp = jet.BoolExp
// StringExp is string expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as string expression.
// Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp
// IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output.
var IntExp = jet.IntExp
// FloatExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as float expression.
// Does not add sql cast to generated sql builder output.
var FloatExp = jet.FloatExp
// TimeExp is time expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time expression.
// Does not add sql cast to generated sql builder output.
var TimeExp = jet.TimeExp
// DateExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as date expression.
// Does not add sql cast to generated sql builder output.
var DateExp = jet.DateExp
// DateTimeExp is timestamp expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp expression.
// Does not add sql cast to generated sql builder output.
var DateTimeExp = jet.TimestampExp
// TimestampExp is timestamp expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp expression.
// Does not add sql cast to generated sql builder output.
var TimestampExp = jet.TimestampExp
// RawArgs is type used to pass optional arguments to Raw method
type RawArgs = map[string]interface{}
// Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()")
// Raw helper methods for each of the sqlite types
var (
Raw = jet.Raw
RawInt = jet.RawInt
RawFloat = jet.RawFloat
RawString = jet.RawString
RawTime = jet.RawTime
RawTimestamp = jet.RawTimestamp
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
var Func = jet.Func
// NewEnumValue creates new named enum value
var NewEnumValue = jet.NewEnumValue

View file

@ -0,0 +1,52 @@
package sqlite
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestRaw(t *testing.T) {
assertSerialize(t, Raw("current_database()"), "(current_database())")
assertDebugSerialize(t, Raw("current_database()"), "(current_database())")
assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}),
"(? + table.colInt + ?)", 11, 22)
assertDebugSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}),
"(11 + table.colInt + 22)")
assertSerialize(t,
Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})),
"(? + (? + table.colInt + ?))",
int64(700), 11, 22)
assertDebugSerialize(t,
Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})),
"(700 + (11 + table.colInt + 22))")
}
func TestRawDuplicateArguments(t *testing.T) {
assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}),
"(? + table.colInt + ?)", 11, 11)
assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}),
"(? + table.colInt + ? + ? + ? + 11)", 11, 2000, 11, 2000)
assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4",
RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}),
`(? + all_types.integer + ? + ? + ? + ? + ?)`, 11, 22, 11, 22, 33, 44)
}
func TestRawInvalidArguments(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, "jet: named argument 'first_arg' does not appear in raw query", r)
}()
assertSerialize(t, Raw("table.colInt + :second_arg", RawArgs{"first_arg": 11}), "(table.colInt + ?)", 22)
}
func TestRawType(t *testing.T) {
assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)),
"((table.colInt + ?) = ?)", 11.22, 3.14)
assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")),
"((table.colStr || ?) = ?)", "doe", "john doe")
}

342
sqlite/functions.go Normal file
View file

@ -0,0 +1,342 @@
package sqlite
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
"time"
)
// ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression {
return jet.NewFunc("", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
var ABSf = jet.ABSf
// ABSi calculates absolute value from int expression
var ABSi = jet.ABSi
// POW calculates power of base with exponent
var POW = jet.POW
// POWER calculates power of base with exponent
var POWER = jet.POWER
// SQRT calculates square root of numeric expression
var SQRT = jet.SQRT
// CBRT calculates cube root of numeric expression
func CBRT(number jet.NumericExpression) jet.FloatExpression {
return POWER(number, Float(1.0).DIV(Float(3.0)))
}
// CEIL calculates ceil of float expression
var CEIL = jet.CEIL
// FLOOR calculates floor of float expression
var FLOOR = jet.FLOOR
// ROUND calculates round of a float expressions with optional precision
var ROUND = jet.ROUND
// SIGN returns sign of float expression
var SIGN = jet.SIGN
// TRUNC calculates trunc of float expression with precision
var TRUNC = TRUNCATE
// TRUNCATE calculates trunc of float expression with precision
var TRUNCATE = func(floatExpression jet.FloatExpression, precision jet.IntegerExpression) jet.FloatExpression {
return jet.NewFloatFunc("TRUNCATE", floatExpression, precision)
}
// LN calculates natural algorithm of float expression
var LN = jet.LN
// LOG calculates logarithm of float expression
var LOG = jet.LOG
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
var AVG = jet.AVG
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
//var BIT_AND = jet.BIT_AND
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
//var BIT_OR = jet.BIT_OR
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
var COUNT = jet.COUNT
// MAX is aggregate function. Returns maximum value of expression across all input values
var MAX = jet.MAX
// MAXi is aggregate function. Returns maximum value of int expression across all input values
var MAXi = jet.MAXi
// MAXf is aggregate function. Returns maximum value of float expression across all input values
var MAXf = jet.MAXf
// MIN is aggregate function. Returns minimum value of int expression across all input values
var MIN = jet.MIN
// MINi is aggregate function. Returns minimum value of int expression across all input values
var MINi = jet.MINi
// MINf is aggregate function. Returns minimum value of float expression across all input values
var MINf = jet.MINf
// SUM is aggregate function. Returns sum of all expressions
var SUM = jet.SUM
// SUMi is aggregate function. Returns sum of integer expression.
var SUMi = jet.SUMi
// SUMf is aggregate function. Returns sum of float expression.
var SUMf = jet.SUMf
// -------------------- Window functions -----------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
var ROW_NUMBER = jet.ROW_NUMBER
// RANK of the current row with gaps; same as row_number of its first peer
var RANK = jet.RANK
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
var DENSE_RANK = jet.DENSE_RANK
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
var PERCENT_RANK = jet.PERCENT_RANK
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
var CUME_DIST = jet.CUME_DIST
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
var NTILE = jet.NTILE
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LAG = jet.LAG
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LEAD = jet.LEAD
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
var FIRST_VALUE = jet.FIRST_VALUE
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
var LAST_VALUE = jet.LAST_VALUE
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------//
// BIT_LENGTH returns number of bits in string expression
//var BIT_LENGTH = jet.BIT_LENGTH
//
//// CHAR_LENGTH returns number of characters in string expression
//var CHAR_LENGTH = jet.CHAR_LENGTH
//
//// OCTET_LENGTH returns number of bytes in string expression
//var OCTET_LENGTH = jet.OCTET_LENGTH
// LOWER returns string expression in lower case
var LOWER = jet.LOWER
// UPPER returns string expression in upper case
var UPPER = jet.UPPER
// LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string
var LTRIM = jet.LTRIM
// RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string
var RTRIM = jet.RTRIM
// CONCAT adds two or more expressions together
//var CONCAT = jet.CONCAT
// CONCAT_WS adds two or more expressions together with a separator.
//var CONCAT_WS = jet.CONCAT_WS
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
//var FORMAT = jet.FORMAT
// LEFTSTR returns first n characters in the string.
// When n is negative, return all but last |n| characters.
//func LEFTSTR(str StringExpression, n IntegerExpression) StringExpression {
// return jet.NewStringFunc("LEFTSTR", str, n)
//}
//
//// RIGHT returns last n characters in the string.
//// When n is negative, return all but first |n| characters.
//func RIGHTSTR(str StringExpression, n IntegerExpression) StringExpression {
// return jet.NewStringFunc("RIGHTSTR", str, n)
//}
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str jet.StringExpression) jet.StringExpression {
return jet.LENGTH(str)
}
// LPAD fills up the string to length length by prepending the characters
// fill (a space by default). If the string is already longer than length
// then it is truncated (on the right).
//func LPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression {
// return jet.LPAD(str, length, text)
//}
// RPAD fills up the string to length length by appending the characters
// fill (a space by default). If the string is already longer than length then it is truncated.
//func RPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression {
// return jet.RPAD(str, length, text)
//}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal
//var MD5 = jet.MD5
// REPEAT repeats string the specified number of times
//var REPEAT = jet.REPEAT
// REPLACE replaces all occurrences in string of substring from with substring to
var REPLACE = jet.REPLACE
// REVERSE returns reversed string.
var REVERSE = jet.REVERSE
// SUBSTR extracts substring
var SUBSTR = jet.SUBSTR
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
var REGEXP_LIKE = jet.REGEXP_LIKE
//----------------- Date/Time Functions and Operators ------------//
// CURRENT_DATE returns current date
var CURRENT_DATE = jet.CURRENT_DATE
// CURRENT_TIME returns current time with time zone
func CURRENT_TIME() TimeExpression {
return TimeExp(jet.CURRENT_TIME())
}
// CURRENT_TIMESTAMP returns current timestamp with time zone
func CURRENT_TIMESTAMP() TimestampExpression {
return TimestampExp(jet.CURRENT_TIMESTAMP())
}
//// NOW returns current datetime
//func NOW() DateTimeExpression {
// //if len(fsp) > 0 {
// // return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0])))
// //}
// //return jet.NewTimestampFunc("NOW")
// return DATETIME(jet.FixedLiteral("now"))
//}
// time-value modifiers
var (
YEARS = modifier("YEARS")
MONTHS = modifier("MONTHS")
DAYS = modifier("DAYS")
HOURS = modifier("HOURS")
MINUTES = modifier("MINUTES")
SECONDS = modifier("SECONDS")
START_OF_YEAR = String("start of year")
START_OF_MONTH = String("start of month")
UNIXEPOCH = String("unixepoch")
LOCALTIME = String("localtime")
UTC = String("UTC")
WEEKDAY = func(value int) Expression {
return String(fmt.Sprintf("WEEKDAY %d", value))
}
)
func modifier(modifierName string) func(value float64) Expression {
return func(value float64) Expression {
return String(fmt.Sprintf("%g %s", value, modifierName))
}
}
// DATE function creates new date from time-value and zero or more time modifiers
func DATE(timeValue interface{}, modifiers ...Expression) DateExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewDateFunc("DATE", exprList...)
}
// TIME function creates new time from time-value and zero or more time modifiers
func TIME(timeValue interface{}, modifiers ...Expression) TimeExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewTimeFunc("TIME", exprList...)
}
// DATETIME function creates new DateTime from time-value and zero or more time modifiers
func DATETIME(timeValue interface{}, modifiers ...Expression) DateTimeExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewTimestampFunc("DATETIME", exprList...)
}
// JULIANDAY returns the number of days since noon in Greenwich on November 24, 4714 B.C
func JULIANDAY(timeValue interface{}, modifiers ...Expression) FloatExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewFloatFunc("JULIANDAY", exprList...)
}
// STRFTIME routine returns the date formatted according to the format string specified as the first argument.
func STRFTIME(format StringExpression, timeValue interface{}, modifiers ...Expression) StringExpression {
exprList := append([]Expression{format}, getFuncExprList(timeValue, modifiers...)...)
return jet.NewStringFunc("strftime", exprList...)
}
func getFuncExprList(timeValue interface{}, modifiers ...Expression) []Expression {
return append([]Expression{getTimeValueExpression(timeValue)}, modifiers...)
}
func getTimeValueExpression(timeValue interface{}) Expression {
switch t := timeValue.(type) {
case string:
return String(t)
case Expression:
return t
case time.Time, int64:
return jet.Literal(t)
}
panic(fmt.Sprintf("jet: Invalid time value %T(%q)", timeValue, timeValue))
}
// TIMESTAMP return a datetime value based on the arguments:
func TIMESTAMP(str StringExpression) TimestampExpression {
return jet.NewTimestampFunc("TIMESTAMP", str)
}
// UNIX_TIMESTAMP returns unix timestamp
func UNIX_TIMESTAMP(str StringExpression) TimestampExpression {
return jet.NewTimestampFunc("UNIX_TIMESTAMP", str)
}
//----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
var EXISTS = jet.EXISTS
// CASE create CASE operator with optional list of expressions
var CASE = jet.CASE

117
sqlite/insert_statement.go Normal file
View file

@ -0,0 +1,117 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
VALUES(value interface{}, values ...interface{}) InsertStatement
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
DEFAULT_VALUES() InsertStatement
ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict
RETURNING(projections ...Projection) InsertStatement
}
func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{
DefaultValues: jet.ClauseOptional{Name: "DEFAULT VALUES", InNewLine: true},
}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert,
&newInsert.ValuesQuery,
&newInsert.OnDuplicateKey,
&newInsert.DefaultValues,
&newInsert.OnConflict,
&newInsert.Returning,
)
newInsert.Insert.Table = table
newInsert.Insert.Columns = columns
newInsert.ValuesQuery.SkipSelectWrap = true
return newInsert
}
type insertStatementImpl struct {
jet.SerializerStatement
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
DefaultValues jet.ClauseOptional
OnConflict onConflictClause
Returning jet.ClauseReturning
}
func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return is
}
// MODEL will insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
return is
}
func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
return is
}
func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
is.OnDuplicateKey = assigments
return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
}
func (is *insertStatementImpl) DEFAULT_VALUES() InsertStatement {
is.DefaultValues.Show = true
return is
}
func (is *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement {
is.Returning.ProjectionList = projections
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
}
func (is *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
is.OnConflict = onConflictClause{
insertStatement: is,
indexExpressions: indexExpressions,
}
return &is.OnConflict
}

View file

@ -0,0 +1,150 @@
package sqlite
import (
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestInvalidInsert(t *testing.T) {
assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list")
}
func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1)
VALUES (?);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1)
VALUES (?);
`, int(1))
}
func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList{table3ColInt}
columnList = append(columnList, table3StrCol)
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2)
VALUES (?, ?);
`, 1, 3)
}
func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), `
INSERT INTO db.table1 (col_timestamp)
VALUES (?);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3)
VALUES (?, ?, ?);
`, 1, 2, 3)
}
func TestInsertMultipleRows(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(1, 2).
VALUES(11, 22).
VALUES(111, 222)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (?, ?),
(?, ?),
(?, ?);
`, 1, 2, 11, 22, 111, 222)
}
func TestInsertValuesFromModel(t *testing.T) {
type Table1Model struct {
Col1 *int
ColFloat float64
}
one := 1
toInsert := Table1Model{
Col1: &one,
ColFloat: 1.11,
}
stmt := table1.INSERT(table1Col1, table1ColFloat).
MODEL(toInsert).
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float)
VALUES (?, ?),
(?, ?);
`
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct {
Col1Prim int
Col2 string
}
newData := Table1Model{
Col1Prim: 1,
Col2: "one",
}
table1.
INSERT(table1Col1, table1ColFloat).
MODEL(newData)
}
func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, r, "jet: data has to be a struct")
}()
table2.INSERT(table2ColInt).MODEL([]int{})
}
func TestInsert_ON_CONFLICT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES (?, ?),
(?, ?),
(?, ?)
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = ?,
col_int = ?,
(col1, col_bool) = (?, ?)
WHERE table1.col1 > ?
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}

70
sqlite/literal.go Normal file
View file

@ -0,0 +1,70 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
"time"
)
// Keywords
var (
STAR = jet.STAR
NULL = jet.NULL
)
// Bool creates new bool literal expression
var Bool = jet.Bool
// Int is constructor for 64 bit signed integer expressions literals.
var Int = jet.Int
// Int8 is constructor for 8 bit signed integer expressions literals.
var Int8 = jet.Int8
// Int16 is constructor for 16 bit signed integer expressions literals.
var Int16 = jet.Int16
// Int32 is constructor for 32 bit signed integer expressions literals.
var Int32 = jet.Int32
// Int64 is constructor for 64 bit signed integer expressions literals.
var Int64 = jet.Int
// Uint8 is constructor for 8 bit unsigned integer expressions literals.
var Uint8 = jet.Uint8
// Uint16 is constructor for 16 bit unsigned integer expressions literals.
var Uint16 = jet.Uint16
// Uint32 is constructor for 32 bit unsigned integer expressions literals.
var Uint32 = jet.Uint32
// Uint64 is constructor for 64 bit unsigned integer expressions literals.
var Uint64 = jet.Uint64
// Float creates new float literal expression from float64 value
var Float = jet.Float
// Decimal creates new float literal expression from string value
var Decimal = jet.Decimal
// String creates new string literal expression
var String = jet.String
// UUID is a helper function to create string literal expression from uuid object
// value can be any uuid type with a String method
var UUID = jet.UUID
// Date creates new date literal expression
func Date(year int, month time.Month, day int) DateExpression {
return DATE(jet.Date(year, month, day))
}
// Time creates new time literal expression
func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression {
return TIME(jet.Time(hour, minute, second, nanoseconds...))
}
// DateTime creates new datetime(timestamp) literal expression
func DateTime(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) DateTimeExpression {
return DATETIME(jet.Timestamp(year, month, day, hour, minute, second, nanoseconds...))
}

80
sqlite/literal_test.go Normal file
View file

@ -0,0 +1,80 @@
package sqlite
import (
"math"
"testing"
"time"
)
func TestBool(t *testing.T) {
assertSerialize(t, Bool(false), `?`, false)
}
func TestInt(t *testing.T) {
assertSerialize(t, Int(11), `?`, int64(11))
}
func TestInt8(t *testing.T) {
val := int8(math.MinInt8)
assertSerialize(t, Int8(val), `?`, val)
}
func TestInt16(t *testing.T) {
val := int16(math.MinInt16)
assertSerialize(t, Int16(val), `?`, val)
}
func TestInt32(t *testing.T) {
val := int32(math.MinInt32)
assertSerialize(t, Int32(val), `?`, val)
}
func TestInt64(t *testing.T) {
val := int64(math.MinInt64)
assertSerialize(t, Int64(val), `?`, val)
}
func TestUint8(t *testing.T) {
val := uint8(math.MaxUint8)
assertSerialize(t, Uint8(val), `?`, val)
}
func TestUint16(t *testing.T) {
val := uint16(math.MaxUint16)
assertSerialize(t, Uint16(val), `?`, val)
}
func TestUint32(t *testing.T) {
val := uint32(math.MaxUint32)
assertSerialize(t, Uint32(val), `?`, val)
}
func TestUint64(t *testing.T) {
val := uint64(math.MaxUint64)
assertSerialize(t, Uint64(val), `?`, val)
}
func TestFloat(t *testing.T) {
assertSerialize(t, Float(12.34), `?`, float64(12.34))
}
func TestString(t *testing.T) {
assertSerialize(t, String("Some text"), `?`, "Some text")
}
var testTime = time.Now()
func TestDate(t *testing.T) {
assertSerialize(t, Date(2014, time.January, 2), "DATE(?)", "2014-01-02")
assertSerialize(t, DATE(testTime), "DATE(?)", testTime)
}
func TestTime(t *testing.T) {
assertSerialize(t, Time(10, 15, 30), `TIME(?)`, "10:15:30")
assertSerialize(t, TIME(testTime), "TIME(?)", testTime)
}
func TestDateTime(t *testing.T) {
assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `DATETIME(?)`, "2010-03-30 10:15:30")
assertSerialize(t, DATETIME(testTime), `DATETIME(?)`, testTime)
}

View file

@ -0,0 +1,84 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
type onConflict interface {
WHERE(indexPredicate BoolExpression) conflictTarget
conflictTarget
}
type conflictTarget interface {
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type onConflictClause struct {
insertStatement InsertStatement
indexExpressions []jet.ColumnExpression
whereClause jet.ClauseWhere
do jet.Serializer
}
func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget {
o.whereClause.Condition = indexPredicate
return o
}
func (o *onConflictClause) DO_NOTHING() InsertStatement {
o.do = jet.Keyword("DO NOTHING")
return o.insertStatement
}
func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement {
o.do = action
return o.insertStatement
}
func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(o.indexExpressions) == 0 && o.do == nil {
return
}
out.NewLine()
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}
o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName)
out.IncreaseIdent(7)
jet.Serialize(o.do, statementType, out)
out.DecreaseIdent(7)
}
type conflictAction interface {
jet.Serializer
WHERE(condition BoolExpression) conflictAction
}
// SET creates conflict action for ON_CONFLICT clause
func SET(assigments ...ColumnAssigment) conflictAction {
conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.set = assigments
return &conflictAction
}
type updateConflictActionImpl struct {
jet.Serializer
doUpdate jet.KeywordClause
set jet.SetClauseNew
where jet.ClauseWhere
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition
return u
}

9
sqlite/operators.go Normal file
View file

@ -0,0 +1,9 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// NOT returns negation of bool expression result
var NOT = jet.NOT
// BIT_NOT inverts every bit in integer expression result
var BIT_NOT = jet.BIT_NOT

186
sqlite/select_statement.go Normal file
View file

@ -0,0 +1,186 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
// RowLock is interface for SELECT statement row lock types
type RowLock = jet.RowLock
// Row lock types
var (
UPDATE = jet.NewRowLock("UPDATE")
SHARE = jet.NewRowLock("SHARE")
)
// Window function clauses
var (
PARTITION_BY = jet.PARTITION_BY
ORDER_BY = jet.ORDER_BY
UNBOUNDED = jet.UNBOUNDED
CURRENT_ROW = jet.CURRENT_ROW
)
// PRECEDING window frame clause
func PRECEDING(offset interface{}) jet.FrameExtent {
return jet.PRECEDING(toJetFrameOffset(offset))
}
// FOLLOWING window frame clause
func FOLLOWING(offset interface{}) jet.FrameExtent {
return jet.FOLLOWING(toJetFrameOffset(offset))
}
// Window is used to specify window reference from WINDOW clause
var Window = jet.WindowName
// SelectStatement is interface for MySQL SELECT statement
type SelectStatement interface {
Statement
jet.HasProjections
Expression
DISTINCT() SelectStatement
FROM(tables ...ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR(lock RowLock) SelectStatement
LOCK_IN_SHARE_MODE() SelectStatement
UNION(rhs SelectStatement) setStatement
UNION_ALL(rhs SelectStatement) setStatement
AsTable(alias string) SelectTable
}
//SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...))
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock)
newSelect.Select.ProjectionList = projections
if table != nil {
newSelect.From.Tables = []jet.Serializer{table}
}
newSelect.Limit.Count = -1
newSelect.Offset.Count = -1
newSelect.ShareLock.Name = "LOCK IN SHARE MODE"
newSelect.ShareLock.InNewLine = true
newSelect.setOperatorsImpl.parent = newSelect
return newSelect
}
type selectStatementImpl struct {
jet.ExpressionStatement
setOperatorsImpl
Select jet.ClauseSelect
From jet.ClauseFrom
Where jet.ClauseWhere
GroupBy jet.ClauseGroupBy
Having jet.ClauseHaving
Window jet.ClauseWindow
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
Offset jet.ClauseOffset
For jet.ClauseFor
ShareLock jet.ClauseOptional
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
s.Select.Distinct = true
return s
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
return s
}
func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement {
s.Where.Condition = condition
return s
}
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.GroupBy.List = groupByClauses
return s
}
func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatement {
s.Having.Condition = boolExpression
return s
}
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement {
s.OrderBy.List = orderByClauses
return s
}
func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement {
s.Limit.Count = limit
return s
}
func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement {
s.Offset.Count = offset
return s
}
func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement {
s.For.Lock = lock
return s
}
func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
s.ShareLock.Show = true
return s
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
}
//-----------------------------------------------------
type windowExpand struct {
selectStatement *selectStatementImpl
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
if len(window) == 0 {
return w.selectStatement
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
}
func toJetFrameOffset(offset interface{}) jet.Serializer {
if offset == UNBOUNDED {
return jet.UNBOUNDED
}
return jet.FixedLiteral(offset)
}

View file

@ -0,0 +1,156 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
"testing"
)
func TestInvalidSelect(t *testing.T) {
assertStatementSqlErr(t, SELECT(nil), "jet: Projection is nil")
}
func TestSelectColumnList(t *testing.T) {
columnList := ColumnList{table2ColInt, table2ColFloat, table3ColInt}
assertStatementSql(t, SELECT(columnList).FROM(table2), `
SELECT table2.col_int AS "table2.col_int",
table2.col_float AS "table2.col_float",
table3.col_int AS "table3.col_int"
FROM db.table2;
`)
}
func TestSelectLiterals(t *testing.T) {
assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), `
SELECT ?,
?,
?
FROM db.table1;
`, int64(1), 2.2, false)
}
func TestSelectDistinct(t *testing.T) {
assertStatementSql(t, SELECT(table1ColBool).DISTINCT().FROM(table1), `
SELECT DISTINCT table1.col_bool AS "table1.col_bool"
FROM db.table1;
`)
}
func TestSelectFrom(t *testing.T) {
assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1;
`)
assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
assertStatementSql(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
}
func TestSelectWhere(t *testing.T) {
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE ?;
`, true)
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE table1.col_int >= ?;
`, int64(10))
}
func TestSelectGroupBy(t *testing.T) {
assertStatementSql(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
GROUP BY table2.col_float;
`)
}
func TestSelectHaving(t *testing.T) {
assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.col_int AS "table3.col_int"
FROM db.table3
HAVING table1.col_bool = ?;
`, true)
}
func TestSelectOrderBy(t *testing.T) {
assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC;
`)
assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC, table2.col_int ASC;
`)
}
func TestSelectLimitOffset(t *testing.T) {
assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT ?;
`, int64(10))
assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT ?
OFFSET ?;
`, int64(10), int64(2))
}
func TestSelectLock(t *testing.T) {
testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
FOR UPDATE;
`)
testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
FOR SHARE NOWAIT;
`)
}
func TestSelect_LOCK_IN_SHARE_MODE(t *testing.T) {
testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).LOCK_IN_SHARE_MODE(), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
LOCK IN SHARE MODE;
`)
}
func TestSelect_NOT_EXISTS(t *testing.T) {
testutils.AssertStatementSql(t,
SELECT(table1ColInt).
FROM(table1).
WHERE(
NOT(EXISTS(
SELECT(table2ColInt).
FROM(table2).
WHERE(
table1ColInt.EQ(table2ColInt),
),
))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE (NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
WHERE table1.col_int = table2.col_int
)));
`)
}

24
sqlite/select_table.go Normal file
View file

@ -0,0 +1,24 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// SelectTable is interface for MySQL sub-queries
type SelectTable interface {
readableTable
jet.SelectTable
}
type selectTableImpl struct {
jet.SelectTable
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}
subQuery.readableTableInterfaceImpl.parent = subQuery
return subQuery
}

99
sqlite/set_statement.go Normal file
View file

@ -0,0 +1,99 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
type setStatement interface {
setOperators
ORDER_BY(orderByClauses ...OrderByClause) setStatement
LIMIT(limit int64) setStatement
OFFSET(offset int64) setStatement
AsTable(alias string) SelectTable
}
type setOperators interface {
jet.Statement
jet.HasProjections
jet.Expression
UNION(rhs SelectStatement) setStatement
UNION_ALL(rhs SelectStatement) setStatement
}
type setOperatorsImpl struct {
parent setOperators
}
func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement {
return UNION(s.parent, rhs)
}
func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement {
return UNION_ALL(s.parent, rhs)
}
type setStatementImpl struct {
jet.ExpressionStatement
setOperatorsImpl
setOperator jet.ClauseSetStmtOperator
}
func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement {
newSetStatement := &setStatementImpl{}
newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement,
&newSetStatement.setOperator)
newSetStatement.setOperator.Operator = operator
newSetStatement.setOperator.All = all
newSetStatement.setOperator.Selects = selects
newSetStatement.setOperator.Limit.Count = -1
newSetStatement.setOperator.Offset.Count = -1
newSetStatement.setOperator.SkipSelectWrap = true
newSetStatement.setOperatorsImpl.parent = newSetStatement
return newSetStatement
}
func (s *setStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) setStatement {
s.setOperator.OrderBy.List = orderByClauses
return s
}
func (s *setStatementImpl) LIMIT(limit int64) setStatement {
s.setOperator.Limit.Count = limit
return s
}
func (s *setStatementImpl) OFFSET(offset int64) setStatement {
s.setOperator.Offset.Count = offset
return s
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
}
const (
union = "UNION"
)
func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement {
return append([]jet.SerializerStatement{lhs, rhs}, selects...)
}

View file

@ -0,0 +1,31 @@
package sqlite
import (
"testing"
)
func TestSelectSets(t *testing.T) {
select1 := SELECT(table1ColBool).FROM(table1)
select2 := SELECT(table2ColBool).FROM(table2)
assertStatementSql(t, select1.UNION(select2), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
UNION
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2;
`)
assertStatementSql(t, select1.UNION_ALL(select2), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
UNION ALL
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2;
`)
}

8
sqlite/statement.go Normal file
View file

@ -0,0 +1,8 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement {
return jet.RawStatement(Dialect, rawQuery, namedArguments...)
}

122
sqlite/table.go Normal file
View file

@ -0,0 +1,122 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Table is interface for MySQL tables
type Table interface {
jet.SerializerTable
readableTable
INSERT(columns ...jet.Column) InsertStatement
UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement
}
type readableTable interface {
// Generates a select query on the current tableName.
SELECT(projection Projection, projections ...Projection) SelectStatement
// Creates a inner join tableName Expression using onCondition.
INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a left join tableName Expression using onCondition.
LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a right join tableName Expression using onCondition.
RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a full join tableName Expression using onCondition.
FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a cross join tableName Expression using onCondition.
CROSS_JOIN(table ReadableTable) joinSelectUpdateTable
}
type joinSelectUpdateTable interface {
ReadableTable
UPDATE(columns ...jet.Column) UpdateStatement
}
// ReadableTable interface
type ReadableTable interface {
readableTable
jet.Serializer
}
type readableTableInterfaceImpl struct {
parent ReadableTable
}
// Generates a select query on the current tableName.
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.
func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.InnerJoin, onCondition)
}
// Creates a left join tableName Expression using onCondition.
func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.LeftJoin, onCondition)
}
// Creates a right join tableName Expression using onCondition.
func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.RightJoin, onCondition)
}
func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.FullJoin, onCondition)
}
func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.CrossJoin, nil)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name, alias string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, alias, columns...),
}
t.readableTableInterfaceImpl.parent = t
t.parent = t
return t
}
type tableImpl struct {
jet.SerializerTable
readableTableInterfaceImpl
parent Table
}
func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) DELETE() DeleteStatement {
return newDeleteStatement(t.parent)
}
type joinTable struct {
tableImpl
jet.JoinTable
}
func newJoinTable(lhs jet.Serializer, rhs jet.Serializer, joinType jet.JoinType, onCondition BoolExpression) Table {
newJoinTable := &joinTable{
JoinTable: jet.NewJoinTable(lhs, rhs, joinType, onCondition),
}
newJoinTable.readableTableInterfaceImpl.parent = newJoinTable
newJoinTable.parent = newJoinTable
return newJoinTable
}

101
sqlite/table_test.go Normal file
View file

@ -0,0 +1,101 @@
package sqlite
import (
"testing"
)
func TestJoinNilInputs(t *testing.T) {
assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}
func TestINNER_JOIN(t *testing.T) {
assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).
INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)
INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(Int(1))).
INNER_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = ?)
INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestLEFT_JOIN(t *testing.T) {
assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)
LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(Int(1))).
LEFT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = ?)
LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestRIGHT_JOIN(t *testing.T) {
assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)
RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))).
RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = ?)
RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestFULL_JOIN(t *testing.T) {
assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)).
FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)
FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(Int(1))).
FULL_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = ?)
FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestCROSS_JOIN(t *testing.T) {
assertSerialize(t, table1.
CROSS_JOIN(table2),
`db.table1
CROSS JOIN db.table2`)
assertSerialize(t, table1.
CROSS_JOIN(table2).
CROSS_JOIN(table3),
`db.table1
CROSS JOIN db.table2
CROSS JOIN db.table3`)
}

27
sqlite/types.go Normal file
View file

@ -0,0 +1,27 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement = jet.Statement
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement = jet.PrintableStatement
// OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY.
type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc

View file

@ -0,0 +1,70 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface {
jet.Statement
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...jet.Projection) UpdateStatement
}
type updateStatementImpl struct {
jet.SerializerStatement
Update jet.ClauseUpdate
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
Returning jet.ClauseReturning
}
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Update,
&update.Set,
&update.SetNew,
&update.Where,
&update.Returning)
update.Update.Table = table
update.Set.Columns = columns
update.Where.Mandatory = true
return update
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u
}
func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
u.Set.Values = jet.UnwindRowFromModel(u.Set.Columns, data)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
u.Returning.ProjectionList = projections
return u
}

View file

@ -0,0 +1,82 @@
package sqlite
import (
"fmt"
"strings"
"testing"
)
func TestUpdateWithOneValue(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?
WHERE table1.col_int >= ?;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33)))
fmt.Println(stmt.Sql())
assertStatementSql(t, stmt, expectedSQL, 1, int64(33))
}
func TestUpdateWithValues(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?,
col_float = ?
WHERE table1.col_int >= ?;
`
stmt := table1.UPDATE(table1ColInt, table1ColFloat).
SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33)))
fmt.Println(stmt.Sql())
assertStatementSql(t, stmt, expectedSQL, 1, 22.2, int64(33))
}
func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_float = (
SELECT table1.col_float AS "table1.col_float"
FROM db.table1
)
WHERE table1.col1 = ?;
`
stmt := table1.
UPDATE(table1ColFloat).
SET(
table1.SELECT(table1ColFloat),
).
WHERE(table1Col1.EQ(Int(2)))
assertStatementSql(t, stmt, expectedSQL, int64(2))
}
func TestUpdateReservedWorldColumn(t *testing.T) {
type table struct {
Load string
}
loadColumn := StringColumn("Load")
assertStatementSql(t,
table1.UPDATE(loadColumn).
MODEL(
table{
Load: "foo",
},
).
WHERE(loadColumn.EQ(String("bar"))), strings.Replace(`
UPDATE db.table1
SET ''Load'' = ?
WHERE ''Load'' = ?;
`, "''", "`", -1), "foo", "bar")
}
func TestInvalidInputs(t *testing.T) {
assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set")
assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause")
}

55
sqlite/utils_test.go Normal file
View file

@ -0,0 +1,55 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/testutils"
"testing"
)
var table1Col1 = IntegerColumn("col1")
var table1ColBool = BoolColumn("col_bool")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1ColString = StringColumn("col_string")
var table1Col3 = IntegerColumn("col3")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColDate = DateColumn("col_date")
var table1ColTime = TimeColumn("col_time")
var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1ColString, table1Col3, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTime)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
var table2ColInt = IntegerColumn("col_int")
var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColDate, table2ColTimestamp)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2")
var table3 = NewTable("db", "table3", "", table3Col1, table3ColInt, table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertSerialize(t, Dialect, clause, query, args...)
}
func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertDebugSerialize(t, Dialect, clause, query, args...)
}
func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertSerializeErr(t, Dialect, clause, errString)
}
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {
testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...)
}
var assertPanicErr = testutils.AssertPanicErr
var assertStatementSql = testutils.AssertStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr

26
sqlite/with_statement.go Normal file
View file

@ -0,0 +1,26 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
}
cte.parent = &cte
return cte
}

View file

@ -1,18 +1,21 @@
package dbconfig package dbconfig
import "fmt" import (
"fmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
)
// Postgres test database connection parameters // Postgres test database connection parameters
const ( const (
Host = "localhost" PgHost = "localhost"
Port = 5432 PgPort = 5432
User = "jet" PgUser = "jet"
Password = "jet" PgPassword = "jet"
DBName = "jetdb" PgDBName = "jetdb"
) )
// PostgresConnectString is PostgreSQL test database connection string // PostgresConnectString is PostgreSQL test database connection string
var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName)
// MySQL test database connection parameters // MySQL test database connection parameters
const ( const (
@ -24,3 +27,10 @@ const (
// MySQLConnectionString is MySQL driver connection string to test database // MySQLConnectionString is MySQL driver connection string to test database
var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort) var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort)
// sqllite
var (
SakilaDBPath = repo.GetTestDataFilePath("/init/sqlite/sakila.db")
ChinookDBPath = repo.GetTestDataFilePath("/init/sqlite/chinook.db")
TestSampleDBPath = repo.GetTestDataFilePath("/init/sqlite/test_sample.db")
)

View file

@ -4,16 +4,21 @@ import (
"database/sql" "database/sql"
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/tests/internal/utils/repo"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/tests/dbconfig"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/tests/dbconfig"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
) )
var testSuite string var testSuite string
@ -38,8 +43,23 @@ func main() {
return return
} }
if testSuite == "sqlite" {
initSQLiteDB()
return
}
initMySQLDB() initMySQLDB()
initPostgresDB() initPostgresDB()
initSQLiteDB()
}
func initSQLiteDB() {
err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila"))
throw.OnError(err)
err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook"))
throw.OnError(err)
err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample"))
throw.OnError(err)
} }
func initMySQLDB() { func initMySQLDB() {
@ -62,7 +82,7 @@ func initMySQLDB() {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err := cmd.Run() err := cmd.Run()
utils.PanicOnError(err) throw.OnError(err)
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
Host: dbconfig.MySqLHost, Host: dbconfig.MySqLHost,
@ -72,7 +92,7 @@ func initMySQLDB() {
DBName: dbName, DBName: dbName,
}) })
utils.PanicOnError(err) throw.OnError(err)
} }
} }
@ -99,24 +119,24 @@ func initPostgresDB() {
execFile(db, "./testdata/init/postgres/"+schemaName+".sql") execFile(db, "./testdata/init/postgres/"+schemaName+".sql")
err = postgres.Generate("./.gentestdata", postgres.DBConnection{ err = postgres.Generate("./.gentestdata", postgres.DBConnection{
Host: dbconfig.Host, Host: dbconfig.PgHost,
Port: 5432, Port: dbconfig.PgPort,
User: dbconfig.User, User: dbconfig.PgUser,
Password: dbconfig.Password, Password: dbconfig.PgPassword,
DBName: dbconfig.DBName, DBName: dbconfig.PgDBName,
SchemaName: schemaName, SchemaName: schemaName,
SslMode: "disable", SslMode: "disable",
}) })
utils.PanicOnError(err) throw.OnError(err)
} }
} }
func execFile(db *sql.DB, sqlFilePath string) { func execFile(db *sql.DB, sqlFilePath string) {
testSampleSql, err := ioutil.ReadFile(sqlFilePath) testSampleSql, err := ioutil.ReadFile(sqlFilePath)
utils.PanicOnError(err) throw.OnError(err)
_, err = db.Exec(string(testSampleSql)) _, err = db.Exec(string(testSampleSql))
utils.PanicOnError(err) throw.OnError(err)
} }
func printOnError(err error) { func printOnError(err error) {

View file

@ -0,0 +1,25 @@
package file
import (
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path"
"testing"
)
// Exists expects file to exist on path constructed from pathElems and returns content of the file
func Exists(t *testing.T, pathElems ...string) (fileContent string) {
modelFilePath := path.Join(pathElems...)
file, err := ioutil.ReadFile(modelFilePath)
require.Nil(t, err)
require.NotEmpty(t, file)
return string(file)
}
// NotExists expects file not to exist on path constructed from pathElems
func NotExists(t *testing.T, pathElems ...string) {
modelFilePath := path.Join(pathElems...)
_, err := ioutil.ReadFile(modelFilePath)
require.True(t, os.IsNotExist(err))
}

View file

@ -0,0 +1,33 @@
package repo
import (
"os/exec"
"path/filepath"
"strings"
)
// GetRootDirPath will return this repo full dir path
func GetRootDirPath() string {
cmd := exec.Command("git", "rev-parse", "--show-toplevel")
byteArr, err := cmd.Output()
if err != nil {
panic(err)
}
return strings.TrimSpace(string(byteArr))
}
// GetTestsDirPath will return tests folder full path
func GetTestsDirPath() string {
return filepath.Join(GetRootDirPath(), "tests")
}
// GetTestsFilePath will return full file path of the file in the tests folder
func GetTestsFilePath(subPath string) string {
return filepath.Join(GetTestsDirPath(), subPath)
}
// GetTestDataFilePath will return full file path of the file in the testdata folder
func GetTestDataFilePath(subPath string) string {
return filepath.Join(GetTestsDirPath(), "testdata", subPath)
}

View file

@ -104,18 +104,18 @@ func TestExpressionOperators(t *testing.T) {
SELECT all_types.'integer' IS NULL AS "result.is_null", SELECT all_types.'integer' IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in", (all_types.small_int_ptr IN (?, ?)) AS "result.in",
(all_types.small_int_ptr IN (( (all_types.small_int_ptr IN (
SELECT all_types.'integer' AS "all_types.integer" SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types FROM test_sample.all_types
))) AS "result.in_select", )) AS "result.in_select",
(CURRENT_USER()) AS "result.raw", (CURRENT_USER()) AS "result.raw",
(? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg",
(? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (( (all_types.small_int_ptr NOT IN (
SELECT all_types.'integer' AS "all_types.integer" SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types FROM test_sample.all_types
))) AS "result.not_in_select" )) AS "result.not_in_select"
FROM test_sample.all_types FROM test_sample.all_types
LIMIT ?; LIMIT ?;
`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2))
@ -467,10 +467,10 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.NOT_LIKE(String("_b_")),
AllTypes.Text.REGEXP_LIKE(String("aba")), AllTypes.Text.REGEXP_LIKE(String("aba")),
AllTypes.Text.REGEXP_LIKE(String("aba"), false), AllTypes.Text.REGEXP_LIKE(String("aba"), false),
String("ABA").REGEXP_LIKE(String("aba"), true), //String("ABA").REGEXP_LIKE(String("aba"), true),
AllTypes.Text.NOT_REGEXP_LIKE(String("aba")), AllTypes.Text.NOT_REGEXP_LIKE(String("aba")),
AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false), AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false),
String("ABA").NOT_REGEXP_LIKE(String("aba"), true), //String("ABA").NOT_REGEXP_LIKE(String("aba"), true),
BIT_LENGTH(AllTypes.Text), BIT_LENGTH(AllTypes.Text),
CHAR_LENGTH(AllTypes.Char), CHAR_LENGTH(AllTypes.Char),
@ -962,7 +962,7 @@ func TestAllTypesInsert(t *testing.T) {
tx, err := db.Begin() tx, err := db.Begin()
require.NoError(t, err) require.NoError(t, err)
stmt := AllTypes.INSERT(AllTypes.AllColumns). stmt := AllTypes.INSERT(AllTypes.AllColumns.Except(AllTypes.TimestampPtr)).
MODEL(toInsert) MODEL(toInsert)
//fmt.Println(stmt.DebugSql()) //fmt.Println(stmt.DebugSql())
@ -970,7 +970,7 @@ func TestAllTypesInsert(t *testing.T) {
testutils.AssertExec(t, stmt, tx, 1) testutils.AssertExec(t, stmt, tx, 1)
var dest model.AllTypes var dest model.AllTypes
err = AllTypes.SELECT(AllTypes.AllColumns). err = AllTypes.SELECT(AllTypes.AllColumns.Except(AllTypes.TimestampPtr)).
WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))). WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))).
Query(tx, &dest) Query(tx, &dest)

Some files were not shown because too many files have changed in this diff Show more