Merge pull request #116 from go-jet/develop

Release 2.7.0
This commit is contained in:
go-jet 2022-01-20 17:41:34 +01:00 committed by GitHub
commit 3e802f8955
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
106 changed files with 4369 additions and 866 deletions

View file

@ -1,28 +1,37 @@
# Golang CircleCI 2.0 configuration file
#
# Check https://circleci.com/docs/2.0/language-go/ for more details
version: 2
version: 2.1
orbs:
codecov: codecov/codecov@3.1.1
jobs:
build-postgres-and-mysql:
build_and_tests:
docker:
# specify the version
- image: circleci/golang:1.13
- image: circleci/postgres:10.8-alpine
environment: # environment variables for primary container
- image: circleci/golang:1.16
- image: circleci/postgres:12
environment:
POSTGRES_USER: jet
POSTGRES_PASSWORD: jet
POSTGRES_DB: jetdb
PGPORT: 50901
- image: circleci/mysql:8.0.16
command: [--default-authentication-plugin=mysql_native_password]
- image: circleci/mysql:8.0.27
command: [ --default-authentication-plugin=mysql_native_password ]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
MYSQL_TCP_PORT: 50902
working_directory: /go/src/github.com/go-jet/jet
- image: circleci/mariadb:10.3
command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
environment: # environment variables for the build itself
TEST_RESULTS: /tmp/test-results # path to where test results will be saved
@ -32,25 +41,22 @@ jobs:
- run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
command: cd tests && make checkout-testdata
- restore_cache: # restores saved cache if no changes are detected since last run
keys:
- go-mod-v4-{{ checksum "go.sum" }}
- run:
name: Install dependencies
command: |
cd /go/src/github.com/go-jet/jet
go get github.com/jstemmer/go-junit-report
go build -o /home/circleci/.local/bin/jet ./cmd/jet/
name: Install jet generator
command: cd tests && make install-jet-gen
- run:
name: Waiting for Postgres to be ready
command: |
for i in `seq 1 10`;
do
nc -z localhost 5432 && echo Success && exit 0
nc -z localhost 50901 && echo Success && exit 0
echo -n .
sleep 1
done
@ -61,39 +67,71 @@ jobs:
command: |
for i in `seq 1 10`;
do
nc -z 127.0.0.1 3306 && echo Success && exit 0
nc -z 127.0.0.1 50902 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run:
name: Waiting for MariaDB to be ready
command: |
for i in `seq 1 10`;
do
nc -z 127.0.0.1 50903 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run:
name: Install MySQL CLI;
command: |
sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client
- run:
name: Create MySQL user and databases
name: Create MySQL/MariaDB user and test databases
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -u jet -pjet -e "create database dvds2"
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database dvds2"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database dvds2"
- run:
name: Init Postgres database
command: |
cd tests
go run ./init/init.go -testsuite all
cd ..
name: Init databases
command: |
cd tests
go run ./init/init.go -testsuite all
# to create test results report
- run:
name: Install go-junit-report
command: go install github.com/jstemmer/go-junit-report@latest
- run: mkdir -p $TEST_RESULTS
# this will run all tests and exclude test files from code coverage report
- run: MY_SQL_SOURCE=MySQL go test -v ./... -covermode=atomic -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/sqlite/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run:
name: Upload code coverage
command: bash <(curl -s https://codecov.io/bash)
# this will run all tests and exclude test files from code coverage report
- run: |
go test -v ./... \
-covermode=atomic \
-coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \
-coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
# run mariaDB tests. No need to collect coverage, because coverage is already included with mysql tests
- run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/
- save_cache:
key: go-mod-v4-{{ checksum "go.sum" }}
paths:
- "/go/pkg/mod"
- codecov/upload:
file: cover.out
- store_artifacts: # Upload test summary for display in Artifacts: https://circleci.com/docs/2.0/artifacts/
path: /tmp/test-results
@ -101,69 +139,9 @@ jobs:
- store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/
path: /tmp/test-results
build-mariadb:
docker:
# specify the version
- image: circleci/golang:1.13
- image: circleci/mariadb:10.3
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself
TEST_RESULTS: /tmp/test-results # path to where test results will be saved
steps:
- checkout
- run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run:
name: Install dependencies
command: |
cd /go/src/github.com/go-jet/jet
go get github.com/jstemmer/go-junit-report
go build -o /home/circleci/.local/bin/jet ./cmd/jet/
- run:
name: Install MySQL CLI;
command: |
sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client
- run:
name: Init MariaDB database
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -u jet -pjet -e "create database dvds2"
- run:
name: Init MariaDB database
command: |
cd tests
go run ./init/init.go -testsuite MariaDB
cd ..
- run:
name: Run MariaDB tests
command: |
MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/
workflows:
version: 2
build_and_test:
jobs:
- build-postgres-and-mysql
- build-mariadb
- build_and_tests

4
.gitignore vendored
View file

@ -19,4 +19,6 @@
gen
.gentestdata
.tests/testdata/
.gen
.gen
.docker
.env

View file

@ -60,28 +60,26 @@ Use the command bellow to add jet as a dependency into `go.mod` project:
$ go get -u github.com/go-jet/jet/v2
```
Jet generator can be installed in the following ways:
Jet generator can be installed in one of the following ways:
1) Install jet generator to GOPATH/bin folder:
1) (Go1.16+) Install jet generator using go install:
```sh
go install github.com/go-jet/jet/v2/cmd/jet@latest
```
2) Install jet generator to GOPATH/bin folder:
```sh
cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet
```
*Make sure GOPATH/bin folder is added to the PATH environment variable.*
```
2) Install jet generator into specific folder:
3) Install jet generator into specific folder:
```sh
git clone https://github.com/go-jet/jet.git
cd jet && go build -o dir_path ./cmd/jet
```
*Make sure `dir_path` folder is added to the PATH environment variable.*
*Make sure that the destination folder is added to the PATH environment variable.*
3) (Go1.16+) Install jet generator using go install:
```sh
go install github.com/go-jet/jet/v2/cmd/jet@latest
```
*Jet generator is installed to the directory named by the GOBIN environment variable,
which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.*
### Quick Start
For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in

View file

@ -3,7 +3,14 @@ package main
import (
"flag"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
sqlitegen "github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/mysql"
postgres2 "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/sqlite"
"os"
"strings"
@ -27,34 +34,17 @@ var (
dbName string
schemaName string
ignoreTables string
ignoreViews string
ignoreEnums string
destDir string
)
func init() {
flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL, MariaDB or SQLite)")
flag.StringVar(&source, "source", "", "Database system name (postgres, mysql, mariadb or sqlite)")
flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port")
flag.StringVar(&user, "user", "", "Database user")
flag.StringVar(&password, "password", "", "The users password")
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional)")
flag.StringVar(&dbName, "dbname", "", "Database name")
flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public") (ignored for MySQL and MariaDB)`)
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL(optional)(default "disable") (ignored for MySQL and MariaDB)`)
flag.StringVar(&destDir, "path", "", "Destination dir for files generated.")
}
func main() {
flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, `
Jet generator 2.6.0
Usage:
-dsn string
Data source name. Unified format for connecting to database.
flag.StringVar(&dsn, "dsn", "", `Data source name. Unified format for connecting to database.
PostgreSQL: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
Example:
postgresql://user:pass@localhost:5432/dbname
@ -63,65 +53,70 @@ Usage:
mysql://jet:jet@tcp(localhost:3306)/dvds
SQLite: https://www.sqlite.org/c3ref/open.html#urifilenameexamples
Example:
file://path/to/database/file
-source string
Database system name (PostgreSQL, MySQL, MariaDB or SQLite)
-host string
Database host path (Example: localhost)
-port int
Database port
-user string
Database user
-password string
The users password
-dbname string
Database name
-params string
Additional connection string parameters(optional)
-schema string
Database schema name. (default "public") (ignored for MySQL, MariaDB and SQLite)
-sslmode string
Whether or not to use SSL(optional) (default "disable") (ignored for MySQL, MariaDB and SQLite)
-path string
Destination dir for files generated.
file://path/to/database/file`)
flag.StringVar(&host, "host", "", "Database host path. Used only if dsn is not set. (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port. Used only if dsn is not set.")
flag.StringVar(&user, "user", "", "Database user. Used only if dsn is not set.")
flag.StringVar(&password, "password", "", "The users password. Used only if dsn is not set.")
flag.StringVar(&dbName, "dbname", "", "Database name. Used only if dsn is not set.")
flag.StringVar(&schemaName, "schema", "public", `Database schema name. Used only if dsn is not set. (default "public")(PostgreSQL only)`)
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.")
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`)
flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore`)
flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore`)
flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore`)
Example commands:
flag.StringVar(&destDir, "path", "", "Destination dir for files generated.")
}
func main() {
flag.Usage = func() {
fmt.Println("Jet generator 2.7.0")
fmt.Println()
fmt.Println("Usage:")
order := []string{
"source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode",
"path",
"ignore-tables", "ignore-views", "ignore-enums",
}
for _, name := range order {
flagEntry := flag.CommandLine.Lookup(name)
fmt.Printf(" -%s\n", flagEntry.Name)
fmt.Printf("\t%s\n", flagEntry.Usage)
}
fmt.Println()
fmt.Println(`Example command:
$ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=dvds -path=./gen
$ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds -path=./gen
$ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -schema=dvds -path=./gen
`)
$ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen
`)
}
flag.Parse()
if dsn == "" {
// validations for separated connection flags.
if source == "" || host == "" || port == 0 || user == "" || dbName == "" {
printErrorAndExit("ERROR: required flag(s) missing")
}
} else {
if source == "" {
// try to get source from schema
source = detectSchema(dsn)
}
// validations when dsn != ""
if source == "" {
printErrorAndExit("ERROR: required -source flag missing.")
}
if dsn == "" && (source == "" || host == "" || port == 0 || user == "" || dbName == "") {
printErrorAndExit("ERROR: required flag(s) missing")
}
source := getSource()
ignoreTablesList := parseList(ignoreTables)
ignoreViewsList := parseList(ignoreViews)
ignoreEnumsList := parseList(ignoreEnums)
var err error
switch strings.ToLower(strings.TrimSpace(source)) {
switch source {
case "postgresql", "postgres":
if dsn != "" {
err = postgresgen.GenerateDSN(dsn, schemaName, destDir)
break
}
genData := postgresgen.DBConnection{
dbConn := postgresgen.DBConnection{
Host: host,
Port: port,
User: user,
@ -133,7 +128,11 @@ Example commands:
SchemaName: schemaName,
}
err = postgresgen.Generate(destDir, genData)
err = postgresgen.Generate(
destDir,
dbConn,
genTemplate(postgres2.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList),
)
case "mysql", "mysqlx", "mariadb":
if dsn != "" {
@ -149,12 +148,24 @@ Example commands:
DBName: dbName,
}
err = mysqlgen.Generate(destDir, dbConn)
err = mysqlgen.Generate(
destDir,
dbConn,
genTemplate(mysql.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList),
)
case "sqlite":
if dsn == "" {
printErrorAndExit("ERROR: required -dsn flag missing.")
}
err = sqlitegen.GenerateDSN(dsn, destDir)
err = sqlitegen.GenerateDSN(
dsn,
destDir,
genTemplate(sqlite.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList),
)
case "":
printErrorAndExit("ERROR: required -source or -dns flag missing.")
default:
printErrorAndExit("ERROR: unknown data source " + source + ". Only postgres, mysql, mariadb and sqlite are supported.")
}
@ -167,10 +178,19 @@ Example commands:
func printErrorAndExit(error string) {
fmt.Println("\n", error)
fmt.Println()
flag.Usage()
os.Exit(-2)
}
func getSource() string {
if source != "" {
return strings.TrimSpace(strings.ToLower(source))
}
return detectSchema(dsn)
}
func detectSchema(dsn string) string {
match := strings.SplitN(dsn, "://", 2)
if len(match) < 2 { // not found
@ -183,5 +203,75 @@ func detectSchema(dsn string) string {
return "sqlite"
}
return match[0]
return strings.ToLower(match[0])
}
func parseList(list string) []string {
ret := strings.Split(list, ",")
for i := 0; i < len(ret); i++ {
ret[i] = strings.ToLower(strings.TrimSpace(ret[i]))
}
return ret
}
func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []string, ignoreEnums []string) template.Template {
shouldSkipTable := func(table metadata.Table) bool {
return utils.StringSliceContains(ignoreTables, strings.ToLower(table.Name))
}
shouldSkipView := func(view metadata.Table) bool {
return utils.StringSliceContains(ignoreViews, strings.ToLower(view.Name))
}
shouldSkipEnum := func(enum metadata.Enum) bool {
return utils.StringSliceContains(ignoreEnums, strings.ToLower(enum.Name))
}
return template.Default(dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
if shouldSkipTable(table) {
return template.TableModel{Skip: true}
}
return template.DefaultTableModel(table)
}).
UseView(func(view metadata.Table) template.ViewModel {
if shouldSkipView(view) {
return template.ViewModel{Skip: true}
}
return template.DefaultViewModel(view)
}).
UseEnum(func(enum metadata.Enum) template.EnumModel {
if shouldSkipEnum(enum) {
return template.EnumModel{Skip: true}
}
return template.DefaultEnumModel(enum)
}),
).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
if shouldSkipTable(table) {
return template.TableSQLBuilder{Skip: true}
}
return template.DefaultTableSQLBuilder(table)
}).
UseView(func(table metadata.Table) template.ViewSQLBuilder {
if shouldSkipView(table) {
return template.ViewSQLBuilder{Skip: true}
}
return template.DefaultViewSQLBuilder(table)
}).
UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder {
if shouldSkipEnum(enum) {
return template.EnumSQLBuilder{Skip: true}
}
return template.DefaultEnumSQLBuilder(enum)
}),
)
})
}

151
doc.go
View file

@ -1,77 +1,156 @@
/*
Package jet is a framework for writing type-safe SQL queries in Go, with ability to easily convert database query
result into desired arbitrary object structure.
Package jet is a complete solution for efficient and high performance database access, consisting of type-safe SQL builder
with code generation and automatic query result data mapping.
Jet currently supports PostgreSQL, MySQL, MariaDB and SQLite. Future releases will add support for additional databases.
Installation
Use the bellow command to add jet as a dependency into go.mod project:
$ go get github.com/go-jet/jet/v2
Use the command bellow to add jet as a dependency into go.mod project:
$ go get -u github.com/go-jet/jet/v2
Use the bellow command to add jet as a dependency into GOPATH project:
$ go get -u github.com/go-jet/jet
Jet generator can be installed in one of the following ways:
Install jet generator to GOPATH bin folder. This will allow generating jet files from the command line.
cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet
1) (Go1.16+) Install jet generator using go install:
go install github.com/go-jet/jet/v2/cmd/jet@latest
2) Install jet generator to GOPATH/bin folder:
cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet
3) Install jet generator into specific folder:
git clone https://github.com/go-jet/jet.git
cd jet && go build -o dir_path ./cmd/jet
Make sure that the destination folder is added to the PATH environment variable.
Make sure GOPATH bin folder is added to the PATH environment variable.
Usage
Jet requires already defined database schema(with tables, enums etc), so that jet generator can generate SQL Builder
and Model files. File generation is very fast, and can be added as every pre-build step.
Sample command:
jet -source=PostgreSQL -host=localhost -port=5432 -user=jet -password=pass -dbname=jetdb -schema=dvds -path=./gen
jet -dsn=postgresql://user:pass@localhost:5432/jetdb -schema=dvds -path=./.gen
Then next step is to import generated SQL Builder and Model files and write SQL queries in Go:
Before we can write SQL queries in Go, we need to import generated SQL builder and model types:
import . "some_path/.gen/jetdb/dvds/table"
import "some_path/.gen/jetdb/dvds/model"
To write SQL queries for PostgreSQL import:
. "github.com/go-jet/jet/v2/postgres"
To write postgres SQL queries we import:
. "github.com/go-jet/jet/v2/postgres" // Dot import is used so that Go code resemble as much as native SQL. It is not mandatory.
To write SQL queries for MySQL and MariaDB import:
. "github.com/go-jet/jet/v2/mysql"
*Dot import is used so that Go code resemble as much as native SQL. Dot import is not mandatory.
Write SQL:
Then we can write the SQL query:
// sub-query
rRatingFilms := SELECT(
Film.FilmID,
Film.Title,
Film.Rating,
).
FROM(Film).
WHERE(Film.Rating.EQ(enum.FilmRating.R)).
AsTable("rFilms")
rRatingFilms :=
SELECT(
Film.FilmID,
Film.Title,
Film.Rating,
).FROM(
Film,
).WHERE(
Film.Rating.EQ(enum.FilmRating.R),
).AsTable("rFilms")
// export column from sub-query
rFilmID := Film.FilmID.From(rRatingFilms)
// main-query
query := SELECT(
stmt :=
SELECT(
Actor.AllColumns,
FilmActor.AllColumns,
rRatingFilms.AllColumns(),
).
FROM(
).FROM(
rRatingFilms.
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)).
INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)
).
ORDER_BY(rFilmID, Actor.ActorID)
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)).
INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)
).ORDER_BY(
rFilmID,
Actor.ActorID,
)
Store result into desired destination:
Now we can run the statement and store the result into desired destination:
var dest []struct {
model.Film
Actors []model.Actor
}
err := query.Query(db, &dest)
err := stmt.Query(db, &dest)
Detail info about all features and use cases can be
We can print a statement to see SQL query and arguments sent to postgres server:
fmt.Println(stmt.Sql())
Output:
SELECT "rFilms"."film.film_id" AS "film.film_id",
"rFilms"."film.title" AS "film.title",
"rFilms"."film.rating" AS "film.rating",
actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update",
film_actor.actor_id AS "film_actor.actor_id",
film_actor.film_id AS "film_actor.film_id",
film_actor.last_update AS "film_actor.last_update"
FROM (
SELECT film.film_id AS "film.film_id",
film.title AS "film.title",
film.rating AS "film.rating"
FROM dvds.film
WHERE film.rating = 'R'
) AS "rFilms"
INNER JOIN dvds.film_actor ON (film_actor.film_id = "rFilms"."film.film_id")
INNER JOIN dvds.actor ON (film_actor.actor_id = actor.actor_id)
WHERE "rFilms"."film.film_id" < $1
ORDER BY "rFilms"."film.film_id" ASC, actor.actor_id ASC;
[50]
If we print destination as json, we'll get:
[
{
"FilmID": 8,
"Title": "Airport Pollock",
"Rating": "R",
"Actors": [
{
"ActorID": 55,
"FirstName": "Fay",
"LastName": "Kilmer",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
{
"ActorID": 96,
"FirstName": "Gene",
"LastName": "Willis",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
...
]
},
{
"FilmID": 17,
"Title": "Alone Trip",
"Actors": [
{
"ActorID": 3,
"FirstName": "Ed",
"LastName": "Chase",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
{
"ActorID": 12,
"FirstName": "Karl",
"LastName": "Berry",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
...
...
]
Detail info about all statements, features and use cases can be
found at project wiki page - https://github.com/go-jet/jet/wiki.
*/
package jet

View file

@ -20,7 +20,7 @@ WHERE table_schema = ? and table_type = ?;
`
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err)
for i := range tables {
@ -32,15 +32,14 @@ WHERE table_schema = ? and table_type = ?;
func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
query := `
WITH primaryKeys AS (
SELECT k.column_name
FROM information_schema.table_constraints t
JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name)
WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY'
)
SELECT COLUMN_NAME AS "column.Name",
IS_NULLABLE = "YES" AS "column.IsNullable",
(EXISTS(SELECT 1 FROM primaryKeys AS pk WHERE pk.column_name = columns.column_name)) AS "column.IsPrimaryKey",
(EXISTS(
SELECT 1
FROM information_schema.table_constraints t
JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name)
WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' AND k.column_name = columns.column_name
)) AS "column.IsPrimaryKey",
IF (COLUMN_TYPE = 'tinyint(1)',
'boolean',
IF (DATA_TYPE='enum',
@ -54,7 +53,7 @@ WHERE table_schema = ? AND table_name = ?
ORDER BY ordinal_position;
`
var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
throw.OnError(err)
return columns
@ -73,7 +72,7 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
Values string
}
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
throw.OnError(err)
var ret []metadata.Enum

View file

@ -19,7 +19,7 @@ WHERE table_schema = $1 and table_type = $2;
`
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err)
for i := range tables {
@ -58,7 +58,7 @@ where table_schema = $1 and table_name = $2
order by ordinal_position;
`
var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
throw.OnError(err)
return columns
@ -76,7 +76,7 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
var result []metadata.Enum
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
throw.OnError(err)
return result

View file

@ -28,7 +28,7 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
_, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
throw.OnError(err)
for i := range tables {
@ -47,7 +47,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
Pk int32
}
err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
_, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
throw.OnError(err)
var columns []metadata.Column

View file

@ -13,7 +13,12 @@ func newAlias(expression Expression, aliasName string) Projection {
}
func (a *alias) fromImpl(subQuery SelectTable) Projection {
column := NewColumnImpl(a.alias, "", nil)
// if alias is in the form "table.column", we break it into two parts so that ProjectionList.As(newAlias) can
// overwrite tableName with a new alias. This method is called only for exporting aliased custom columns.
// Generated columns have default aliasing.
tableName, columnName := extractTableAndColumnName(a.alias)
column := NewColumnImpl(columnName, tableName, nil)
column.subQuery = subQuery
return &column

View file

@ -6,7 +6,6 @@ import (
func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
}
func TestBoolExpressionNOT_EQ(t *testing.T) {

View file

@ -18,8 +18,9 @@ type ClauseWithProjections interface {
// ClauseSelect struct
type ClauseSelect struct {
Distinct bool
ProjectionList []Projection
Distinct bool
DistinctOnColumns []ColumnExpression
ProjectionList []Projection
}
// Projections returns list of projections for select clause
@ -36,6 +37,12 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
out.WriteString("DISTINCT")
}
if len(s.DistinctOnColumns) > 0 {
out.WriteString("ON (")
SerializeColumnExpressions(s.DistinctOnColumns, statementType, out)
out.WriteByte(')')
}
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
@ -45,6 +52,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
// ClauseFrom struct
type ClauseFrom struct {
Name string
Tables []Serializer
}
@ -54,7 +62,11 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, opt
return
}
out.NewLine()
out.WriteString("FROM")
if f.Name != "" {
out.WriteString(f.Name)
} else {
out.WriteString("FROM")
}
out.IncreaseIdent()
for i, table := range f.Tables {

View file

@ -13,6 +13,8 @@ type DateExpression interface {
LT_EQ(rhs DateExpression) BoolExpression
GT(rhs DateExpression) BoolExpression
GT_EQ(rhs DateExpression) BoolExpression
BETWEEN(min, max DateExpression) BoolExpression
NOT_BETWEEN(min, max DateExpression) BoolExpression
ADD(rhs Interval) TimestampExpression
SUB(rhs Interval) TimestampExpression
@ -54,6 +56,14 @@ func (d *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression {
return GtEq(d.parent, rhs)
}
func (d *dateInterfaceImpl) BETWEEN(min, max DateExpression) BoolExpression {
return NewBetweenOperatorExpression(d.parent, min, max, false)
}
func (d *dateInterfaceImpl) NOT_BETWEEN(min, max DateExpression) BoolExpression {
return NewBetweenOperatorExpression(d.parent, min, max, true)
}
func (d *dateInterfaceImpl) ADD(rhs Interval) TimestampExpression {
return TimestampExp(Add(d.parent, rhs))
}

View file

@ -1,5 +1,7 @@
package jet
import "fmt"
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
@ -33,7 +35,8 @@ type ExpressionInterfaceImpl struct {
}
func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
return e.Parent
panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s",
subQuery.Alias(), serializeToDefaultDebugString(e.Parent)))
}
// IS_NULL tests expression whether it is a NULL value.
@ -93,7 +96,7 @@ type binaryOperatorExpression struct {
}
// NewBinaryOperatorExpression creates new binaryOperatorExpression
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression {
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression {
binaryExpression := &binaryOperatorExpression{
lhs: lhs,
rhs: rhs,
@ -106,23 +109,10 @@ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additiona
binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression
return binaryExpression
return complexExpr(binaryExpression)
}
func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator")
}
if c.rhs == nil {
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, NoWrap)
if wrap {
out.WriteString("(")
}
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, FallTrough(options)...)
@ -131,10 +121,6 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
out.WriteString(c.operator)
c.rhs.serialize(statement, out, FallTrough(options)...)
}
if wrap {
out.WriteString(")")
}
}
// A prefix operator Expression
@ -145,27 +131,19 @@ type prefixExpression struct {
operator string
}
func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression {
func newPrefixOperatorExpression(expression Expression, operator string) Expression {
prefixExpression := &prefixExpression{
expression: expression,
operator: operator,
}
prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression
return prefixExpression
return complexExpr(prefixExpression)
}
func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
out.WriteString(p.operator)
if p.expression == nil {
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(")")
}
// A postfix operator Expression
@ -188,11 +166,77 @@ func newPostfixOperatorExpression(expression Expression, operator string) *postf
}
func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if p.expression == nil {
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(p.operator)
}
type betweenOperatorExpression struct {
ExpressionInterfaceImpl
expression Expression
notBetween bool
min Expression
max Expression
}
// NewBetweenOperatorExpression creates new BETWEEN operator expression
func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bool) BoolExpression {
newBetweenOperator := &betweenOperatorExpression{
expression: expression,
notBetween: notBetween,
min: min,
max: max,
}
newBetweenOperator.ExpressionInterfaceImpl.Parent = newBetweenOperator
return BoolExp(complexExpr(newBetweenOperator))
}
func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
p.expression.serialize(statement, out, FallTrough(options)...)
if p.notBetween {
out.WriteString("NOT")
}
out.WriteString("BETWEEN")
p.min.serialize(statement, out, FallTrough(options)...)
out.WriteString("AND")
p.max.serialize(statement, out, FallTrough(options)...)
}
type complexExpression struct {
ExpressionInterfaceImpl
expressions Expression
}
func complexExpr(expressions Expression) Expression {
complexExpression := &complexExpression{expressions: expressions}
complexExpression.ExpressionInterfaceImpl.Parent = complexExpression
return complexExpression
}
func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
}
s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper
if !contains(options, NoWrap) {
out.WriteString(")")
}
}
type skipParenthesisWrap struct {
Expression
}
func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}
// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
}

View file

@ -4,10 +4,6 @@ import (
"testing"
)
func TestInvalidExpression(t *testing.T) {
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
}
func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")

View file

@ -14,6 +14,8 @@ type FloatExpression interface {
LT_EQ(rhs FloatExpression) BoolExpression
GT(rhs FloatExpression) BoolExpression
GT_EQ(rhs FloatExpression) BoolExpression
BETWEEN(min, max FloatExpression) BoolExpression
NOT_BETWEEN(min, max FloatExpression) BoolExpression
ADD(rhs NumericExpression) FloatExpression
SUB(rhs NumericExpression) FloatExpression
@ -60,6 +62,14 @@ func (n *floatInterfaceImpl) LT_EQ(rhs FloatExpression) BoolExpression {
return LtEq(n.parent, rhs)
}
func (n *floatInterfaceImpl) BETWEEN(min, max FloatExpression) BoolExpression {
return NewBetweenOperatorExpression(n.parent, min, max, false)
}
func (n *floatInterfaceImpl) NOT_BETWEEN(min, max FloatExpression) BoolExpression {
return NewBetweenOperatorExpression(n.parent, min, max, true)
}
func (n *floatInterfaceImpl) ADD(rhs NumericExpression) FloatExpression {
return FloatExp(Add(n.parent, rhs))
}

View file

@ -81,7 +81,7 @@ func LOG(floatExpression FloatExpression) FloatExpression {
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) floatWindowExpression {
func AVG(numericExpression Expression) floatWindowExpression {
return NewFloatWindowFunc("AVG", numericExpression)
}
@ -594,7 +594,7 @@ type funcExpressionImpl struct {
func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl {
funcExp := &funcExpressionImpl{
name: name,
expressions: expressions,
expressions: parameters(expressions),
}
if parent != nil {
@ -606,9 +606,22 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp
}
func parameters(expressions []Expression) []Expression {
var ret []Expression
for _, expression := range expressions {
if _, isStatement := expression.(Statement); isStatement {
ret = append(ret, expression)
} else {
ret = append(ret, skipWrap(expression))
}
}
return ret
}
// NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := NewFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun)
newFun.ExpressionInterfaceImpl.Parent = windowExpr
@ -698,12 +711,12 @@ type integerFunc struct {
}
func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
floatFunc := &integerFunc{}
intFunc := &integerFunc{}
floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.integerInterfaceImpl.parent = floatFunc
intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc)
intFunc.integerInterfaceImpl.parent = intFunc
return floatFunc
return intFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
@ -806,7 +819,7 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
return timestampzFunc
}
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
func Func(name string, expressions ...Expression) Expression {
return NewFunc(name, expressions, nil)
}

View file

@ -5,46 +5,29 @@ type IntegerExpression interface {
Expression
numericExpression
// Check if expression is equal to rhs
EQ(rhs IntegerExpression) BoolExpression
// Check if expression is not equal to rhs
NOT_EQ(rhs IntegerExpression) BoolExpression
// Check if expression is distinct from rhs
IS_DISTINCT_FROM(rhs IntegerExpression) BoolExpression
// Check if expression is not distinct from rhs
IS_NOT_DISTINCT_FROM(rhs IntegerExpression) BoolExpression
// Check if expression is less then rhs
LT(rhs IntegerExpression) BoolExpression
// Check if expression is less then equal rhs
LT_EQ(rhs IntegerExpression) BoolExpression
// Check if expression is greater then rhs
GT(rhs IntegerExpression) BoolExpression
// Check if expression is greater then equal rhs
GT_EQ(rhs IntegerExpression) BoolExpression
BETWEEN(min, max IntegerExpression) BoolExpression
NOT_BETWEEN(min, max IntegerExpression) BoolExpression
// expression + rhs
ADD(rhs IntegerExpression) IntegerExpression
// expression - rhs
SUB(rhs IntegerExpression) IntegerExpression
// expression * rhs
MUL(rhs IntegerExpression) IntegerExpression
// expression / rhs
DIV(rhs IntegerExpression) IntegerExpression
// expression % rhs
MOD(rhs IntegerExpression) IntegerExpression
// expression ^ rhs
POW(rhs IntegerExpression) IntegerExpression
// expression & rhs
BIT_AND(rhs IntegerExpression) IntegerExpression
// expression | rhs
BIT_OR(rhs IntegerExpression) IntegerExpression
// expression # rhs
BIT_XOR(rhs IntegerExpression) IntegerExpression
// expression << rhs
BIT_SHIFT_LEFT(shift IntegerExpression) IntegerExpression
// expression >> rhs
BIT_SHIFT_RIGHT(shift IntegerExpression) IntegerExpression
}
@ -85,6 +68,14 @@ func (i *integerInterfaceImpl) LT_EQ(rhs IntegerExpression) BoolExpression {
return LtEq(i.parent, rhs)
}
func (i *integerInterfaceImpl) BETWEEN(min, max IntegerExpression) BoolExpression {
return NewBetweenOperatorExpression(i.parent, min, max, false)
}
func (i *integerInterfaceImpl) NOT_BETWEEN(min, max IntegerExpression) BoolExpression {
return NewBetweenOperatorExpression(i.parent, min, max, true)
}
func (i *integerInterfaceImpl) ADD(rhs IntegerExpression) IntegerExpression {
return IntExp(Add(i.parent, rhs))
}

View file

@ -99,3 +99,9 @@ func TestIntExpressionIntExp(t *testing.T) {
assertClauseSerialize(t, IntExp(table1ColFloat.ADD(table2ColFloat)).ADD(Int(11)),
"((table1.col_float + table2.col_float) + $1)", int64(11))
}
func TestIntExpressionBetween(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BETWEEN(Int(1), table1Col3), "(table1.col_int BETWEEN $1 AND table1.col3)", int64(1))
assertClauseSerialize(t, table1ColInt.BETWEEN(Int(1), table1Col3).AND(table1ColBool),
"((table1.col_int BETWEEN $1 AND table1.col3) AND table1.col_bool)", int64(1))
}

View file

@ -1,6 +1,11 @@
package jet
import "context"
import (
"context"
"runtime"
"strings"
"time"
)
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement interface {
@ -8,7 +13,7 @@ type PrintableStatement interface {
DebugSql() (query string)
}
// LoggerFunc is a definition of a function user can implement to support automatic statement logging.
// LoggerFunc is a function user can implement to support automatic statement logging.
type LoggerFunc func(ctx context.Context, statement PrintableStatement)
var logger LoggerFunc
@ -17,3 +22,60 @@ var logger LoggerFunc
func SetLoggerFunc(loggerFunc LoggerFunc) {
logger = loggerFunc
}
func callLogger(ctx context.Context, statement Statement) {
if logger != nil {
logger(ctx, statement)
}
}
// QueryInfo contains information about executed query
type QueryInfo struct {
Statement PrintableStatement
// Depending on how the statement is executed, RowsProcessed is:
// - Number of rows returned for Query() and QueryContext() methods
// - RowsAffected() for Exec() and ExecContext() methods
// - Always 0 for Rows() method.
RowsProcessed int64
Duration time.Duration
Err error
}
// QueryLoggerFunc is a function user can implement to retrieve more information about statement executed.
type QueryLoggerFunc func(ctx context.Context, info QueryInfo)
var queryLoggerFunc QueryLoggerFunc
// SetQueryLogger sets automatic query logging function.
func SetQueryLogger(loggerFunc QueryLoggerFunc) {
queryLoggerFunc = loggerFunc
}
func callQueryLoggerFunc(ctx context.Context, info QueryInfo) {
if queryLoggerFunc != nil {
queryLoggerFunc(ctx, info)
}
}
// Caller returns information about statement caller
func (q QueryInfo) Caller() (file string, line int, function string) {
skip := 4
// depending on execution type (Query, QueryContext, Exec, ...) looped once or twice
for {
var pc uintptr
var ok bool
pc, file, line, ok = runtime.Caller(skip)
if !ok {
return
}
funcDetails := runtime.FuncForPC(pc)
if !strings.Contains(funcDetails.Name(), "github.com/go-jet/jet/v2/internal") {
function = funcDetails.Name()
return
}
skip++
}
}

View file

@ -173,3 +173,8 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o
out.WriteString("END)")
}
// DISTINCT operator can be used to return distinct values of expr
func DISTINCT(expr Expression) Expression {
return newPrefixOperatorExpression(expr, "DISTINCT")
}

View file

@ -0,0 +1,60 @@
package jet
// MODE computes the most frequent value of the aggregated argument
func MODE() *OrderSetAggregateFunc {
return newOrderSetAggregateFunction("MODE", nil)
}
// PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of
// aggregated argument values. This will interpolate between adjacent input items if needed.
func PERCENTILE_CONT(fraction FloatExpression) *OrderSetAggregateFunc {
return newOrderSetAggregateFunction("PERCENTILE_CONT", fraction)
}
// PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position
// in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type.
func PERCENTILE_DISC(fraction FloatExpression) *OrderSetAggregateFunc {
return newOrderSetAggregateFunction("PERCENTILE_DISC", fraction)
}
// OrderSetAggregateFunc implementation of order set aggregate function
type OrderSetAggregateFunc struct {
name string
fraction FloatExpression
orderBy Window
}
func newOrderSetAggregateFunction(name string, fraction FloatExpression) *OrderSetAggregateFunc {
return &OrderSetAggregateFunc{
name: name,
fraction: fraction,
}
}
// WITHIN_GROUP_ORDER_BY specifies ordered set of aggregated argument values
func (p *OrderSetAggregateFunc) WITHIN_GROUP_ORDER_BY(orderBy OrderByClause) Expression {
p.orderBy = ORDER_BY(orderBy)
return newOrderSetAggregateFuncExpression(*p)
}
func newOrderSetAggregateFuncExpression(aggFunc OrderSetAggregateFunc) *orderSetAggregateFuncExpression {
ret := &orderSetAggregateFuncExpression{
OrderSetAggregateFunc: aggFunc,
}
ret.ExpressionInterfaceImpl.Parent = ret
return ret
}
type orderSetAggregateFuncExpression struct {
ExpressionInterfaceImpl
OrderSetAggregateFunc
}
func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name)
WRAP(p.fraction).serialize(statement, out, FallTrough(options)...)
out.WriteString("WITHIN GROUP")
p.orderBy.serialize(statement, out)
}

View file

@ -1,5 +1,7 @@
package jet
import "strings"
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection interface {
serializeForProjection(statement StatementType, out *SQLBuilder)
@ -14,16 +16,68 @@ func SerializeForProjection(projection Projection, statementType StatementType,
// ProjectionList is a redefined type, so that ProjectionList can be used as a Projection.
type ProjectionList []Projection
func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection {
func (pl ProjectionList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, projection := range cl {
for _, projection := range pl {
newProjectionList = append(newProjectionList, projection.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) {
SerializeProjectionList(statement, cl, out)
func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) {
SerializeProjectionList(statement, pl, out)
}
// As will create new projection list where each column is wrapped with a new table alias.
// tableAlias should be in the form 'name' or 'name.*'.
// For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will
// have a column wrapped in alias 'Musician.Name'.
func (pl ProjectionList) As(tableAlias string) ProjectionList {
tableAlias = strings.TrimRight(tableAlias, ".*")
newProjectionList := ProjectionList{}
for _, projection := range pl {
switch p := projection.(type) {
case ProjectionList:
newProjectionList = append(newProjectionList, p.As(tableAlias))
case ColumnExpression:
newProjectionList = append(newProjectionList, newAlias(p, tableAlias+"."+p.Name()))
case *alias:
newAlias := *p
_, columnName := extractTableAndColumnName(newAlias.alias)
newAlias.alias = tableAlias + "." + columnName
newProjectionList = append(newProjectionList, &newAlias)
}
}
return newProjectionList
}
// Except will create new projection list in which columns contained in excluded column names are removed
func (pl ProjectionList) Except(toExclude ...Column) ProjectionList {
excludedColumnList := UnwidColumnList(toExclude)
excludedColumnNames := map[string]bool{}
for _, excludedColumn := range excludedColumnList {
excludedColumnNames[excludedColumn.Name()] = true
}
var ret ProjectionList
for _, projection := range pl {
switch p := projection.(type) {
case ProjectionList:
ret = append(ret, p.Except(toExclude...))
case ColumnExpression:
if excludedColumnNames[p.Name()] {
continue
}
ret = append(ret, p)
}
}
return ret
}

View file

@ -0,0 +1,46 @@
package jet
import "testing"
func TestProjectionAs(t *testing.T) {
projectionList := ProjectionList{
table1Col3,
SUM(table1ColInt).AS("sum"),
SUM(table1ColInt).AS("table.sum"),
ProjectionList{
table1ColBool,
AVG(table1ColInt).AS("avg"),
AVG(table1ColInt).AS("t.avg"),
},
}
aliasedProjectionList := projectionList.As("new_alias.*")
assertProjectionSerialize(t, aliasedProjectionList,
`table1.col3 AS "new_alias.col3",
SUM(table1.col_int) AS "new_alias.sum",
SUM(table1.col_int) AS "new_alias.sum",
table1.col_bool AS "new_alias.col_bool",
AVG(table1.col_int) AS "new_alias.avg",
AVG(table1.col_int) AS "new_alias.avg"`)
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery"))
assertProjectionSerialize(t, subQueryProjections,
`"subQuery"."table1.col3" AS "table1.col3",
"subQuery".sum AS "sum",
"subQuery"."table.sum" AS "table.sum",
"subQuery"."table1.col_bool" AS "table1.col_bool",
"subQuery".avg AS "avg",
"subQuery"."t.avg" AS "t.avg"`)
aliasedSubQueryProjectionList := subQueryProjections.(ProjectionList).As("subAlias")
assertProjectionSerialize(t, aliasedSubQueryProjectionList,
`"subQuery"."table1.col3" AS "subAlias.col3",
"subQuery".sum AS "subAlias.sum",
"subQuery"."table.sum" AS "subAlias.sum",
"subQuery"."table1.col_bool" AS "subAlias.col_bool",
"subQuery".avg AS "subAlias.avg",
"subQuery"."t.avg" AS "subAlias.avg"`)
}

View file

@ -2,38 +2,41 @@ package jet
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Serializer
SerializerHasProjections
Alias() string
AllColumns() ProjectionList
}
type selectTableImpl struct {
selectStmt SerializerStatement
alias string
Statement SerializerHasProjections
alias string
}
// NewSelectTable func
func NewSelectTable(selectStmt SerializerStatement, alias string) selectTableImpl {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias}
func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
}
return selectTable
}
func (s selectTableImpl) projections() ProjectionList {
return s.Statement.projections()
}
func (s selectTableImpl) Alias() string {
return s.alias
}
func (s selectTableImpl) AllColumns() ProjectionList {
statementWithProjections, ok := s.selectStmt.(HasProjections)
if !ok {
return ProjectionList{}
}
projectionList := statementWithProjections.projections().fromImpl(s)
projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out)
s.Statement.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)
@ -52,7 +55,7 @@ func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("LATERAL")
s.selectStmt.serialize(statement, out)
s.Statement.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)

View file

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"github.com/go-jet/jet/v2/qrm"
"time"
)
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
@ -21,9 +22,9 @@ type Statement interface {
// Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error
//Exec executes statement over db connection/transaction without returning any rows.
// Exec executes statement over db connection/transaction without returning any rows.
Exec(db qrm.DB) (sql.Result, error)
//Exec executes statement with context over db connection/transaction without returning any rows.
// ExecContext executes statement with context over db connection/transaction without returning any rows.
ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error)
// Rows executes statements over db connection/transaction and returns rows
Rows(ctx context.Context, db qrm.DB) (*Rows, error)
@ -51,6 +52,12 @@ type HasProjections interface {
projections() ProjectionList
}
// SerializerHasProjections interface is combination of Serializer and HasProjections interface
type SerializerHasProjections interface {
Serializer
HasProjections
}
// serializerStatementInterfaceImpl struct
type serializerStatementInterfaceImpl struct {
dialect Dialect
@ -78,12 +85,7 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
}
func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error {
query, args := s.Sql()
ctx := context.Background()
callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
return s.QueryContext(context.Background(), db, destination)
}
func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error {
@ -91,15 +93,25 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
var rowsProcessed int64
var err error
duration := duration(func() {
rowsProcessed, err = qrm.Query(ctx, db, query, args, destination)
})
callQueryLoggerFunc(ctx, QueryInfo{
Statement: s,
RowsProcessed: rowsProcessed,
Duration: duration,
Err: err,
})
return err
}
func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) {
query, args := s.Sql()
callLogger(context.Background(), s)
return db.Exec(query, args...)
return s.ExecContext(context.Background(), db)
}
func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) {
@ -107,7 +119,24 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q
callLogger(ctx, s)
return db.ExecContext(ctx, query, args...)
duration := duration(func() {
res, err = db.ExecContext(ctx, query, args...)
})
var rowsAffected int64
if err == nil {
rowsAffected, _ = res.RowsAffected()
}
callQueryLoggerFunc(ctx, QueryInfo{
Statement: s,
RowsProcessed: rowsAffected,
Duration: duration,
Err: err,
})
return res, err
}
func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) {
@ -115,7 +144,18 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB)
callLogger(ctx, s)
rows, err := db.QueryContext(ctx, query, args...)
var rows *sql.Rows
var err error
duration := duration(func() {
rows, err = db.QueryContext(ctx, query, args...)
})
callQueryLoggerFunc(ctx, QueryInfo{
Statement: s,
Duration: duration,
Err: err,
})
if err != nil {
return nil, err
@ -124,10 +164,12 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB)
return &Rows{rows}, nil
}
func callLogger(ctx context.Context, statement Statement) {
if logger != nil {
logger(ctx, statement)
}
func duration(f func()) time.Duration {
start := time.Now()
f()
return time.Now().Sub(start)
}
// ExpressionStatement interfacess
@ -200,7 +242,7 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti
}
for _, clause := range s.Clauses {
clause.Serialize(statement, out, FallTrough(options)...)
clause.Serialize(s.statementType, out, FallTrough(options)...)
}
if contains(options, Ident) {

View file

@ -13,6 +13,8 @@ type StringExpression interface {
LT_EQ(rhs StringExpression) BoolExpression
GT(rhs StringExpression) BoolExpression
GT_EQ(rhs StringExpression) BoolExpression
BETWEEN(min, max StringExpression) BoolExpression
NOT_BETWEEN(min, max StringExpression) BoolExpression
CONCAT(rhs Expression) StringExpression
@ -59,6 +61,14 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression {
return LtEq(s.parent, rhs)
}
func (s *stringInterfaceImpl) BETWEEN(min, max StringExpression) BoolExpression {
return NewBetweenOperatorExpression(s.parent, min, max, false)
}
func (s *stringInterfaceImpl) NOT_BETWEEN(min, max StringExpression) BoolExpression {
return NewBetweenOperatorExpression(s.parent, min, max, true)
}
func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression {
return newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator)
}

View file

@ -13,6 +13,8 @@ type TimeExpression interface {
LT_EQ(rhs TimeExpression) BoolExpression
GT(rhs TimeExpression) BoolExpression
GT_EQ(rhs TimeExpression) BoolExpression
BETWEEN(min, max TimeExpression) BoolExpression
NOT_BETWEEN(min, max TimeExpression) BoolExpression
ADD(rhs Interval) TimeExpression
SUB(rhs Interval) TimeExpression
@ -54,6 +56,14 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression {
return GtEq(t.parent, rhs)
}
func (t *timeInterfaceImpl) BETWEEN(min, max TimeExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, false)
}
func (t *timeInterfaceImpl) NOT_BETWEEN(min, max TimeExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, true)
}
func (t *timeInterfaceImpl) ADD(rhs Interval) TimeExpression {
return TimeExp(Add(t.parent, rhs))
}

View file

@ -13,6 +13,8 @@ type TimestampExpression interface {
LT_EQ(rhs TimestampExpression) BoolExpression
GT(rhs TimestampExpression) BoolExpression
GT_EQ(rhs TimestampExpression) BoolExpression
BETWEEN(min, max TimestampExpression) BoolExpression
NOT_BETWEEN(min, max TimestampExpression) BoolExpression
ADD(rhs Interval) TimestampExpression
SUB(rhs Interval) TimestampExpression
@ -54,6 +56,14 @@ func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression {
return GtEq(t.parent, rhs)
}
func (t *timestampInterfaceImpl) BETWEEN(min, max TimestampExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, false)
}
func (t *timestampInterfaceImpl) NOT_BETWEEN(min, max TimestampExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, true)
}
func (t *timestampInterfaceImpl) ADD(rhs Interval) TimestampExpression {
return TimestampExp(Add(t.parent, rhs))
}

View file

@ -13,6 +13,8 @@ type TimestampzExpression interface {
LT_EQ(rhs TimestampzExpression) BoolExpression
GT(rhs TimestampzExpression) BoolExpression
GT_EQ(rhs TimestampzExpression) BoolExpression
BETWEEN(min, max TimestampzExpression) BoolExpression
NOT_BETWEEN(min, max TimestampzExpression) BoolExpression
ADD(rhs Interval) TimestampzExpression
SUB(rhs Interval) TimestampzExpression
@ -54,6 +56,14 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression
return GtEq(t.parent, rhs)
}
func (t *timestampzInterfaceImpl) BETWEEN(min, max TimestampzExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, false)
}
func (t *timestampzInterfaceImpl) NOT_BETWEEN(min, max TimestampzExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, true)
}
func (t *timestampzInterfaceImpl) ADD(rhs Interval) TimestampzExpression {
return TimestampzExp(Add(t.parent, rhs))
}

View file

@ -13,6 +13,8 @@ type TimezExpression interface {
LT_EQ(rhs TimezExpression) BoolExpression
GT(rhs TimezExpression) BoolExpression
GT_EQ(rhs TimezExpression) BoolExpression
BETWEEN(min, max TimezExpression) BoolExpression
NOT_BETWEEN(min, max TimezExpression) BoolExpression
ADD(rhs Interval) TimezExpression
SUB(rhs Interval) TimezExpression
@ -54,6 +56,14 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression {
return GtEq(t.parent, rhs)
}
func (t *timezInterfaceImpl) BETWEEN(min, max TimezExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, false)
}
func (t *timezInterfaceImpl) NOT_BETWEEN(min, max TimezExpression) BoolExpression {
return NewBetweenOperatorExpression(t.parent, min, max, true)
}
func (t *timezInterfaceImpl) ADD(rhs Interval) TimezExpression {
return TimezExp(Add(t.parent, rhs))
}

View file

@ -3,6 +3,7 @@ package jet
import (
"github.com/go-jet/jet/v2/internal/utils"
"reflect"
"strings"
)
// SerializeClauseList func
@ -33,7 +34,9 @@ func serializeExpressionList(
out.WriteString(separator)
}
expression.serialize(statement, out, options...)
if expression != nil {
expression.serialize(statement, out, options...)
}
}
}
@ -68,8 +71,8 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) {
}
}
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType,
// SerializeColumnExpressions func
func SerializeColumnExpressions(columns []ColumnExpression, statementType StatementType,
out *SQLBuilder, options ...SerializeOption) {
for i, col := range columns {
if i > 0 {
@ -84,6 +87,21 @@ func SerializeColumnExpressionNames(columns []ColumnExpression, statementType St
}
}
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, out *SQLBuilder) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
out.WriteIdentifier(col.Name())
}
}
// ExpressionListToSerializerList converts list of expressions to list of serializers
func ExpressionListToSerializerList(expressions []Expression) []Serializer {
var ret []Serializer
@ -229,3 +247,22 @@ func OptionalOrDefaultExpression(defaultExpression Expression, expression ...Exp
return defaultExpression
}
func extractTableAndColumnName(alias string) (tableName string, columnName string) {
parts := strings.Split(alias, ".")
if len(parts) >= 2 {
tableName = parts[0]
columnName = parts[1]
} else {
columnName = parts[0]
}
return
}
func serializeToDefaultDebugString(expr Serializer) string {
out := SQLBuilder{Dialect: defaultDialect, Debug: true}
expr.serialize(SelectStatementType, &out)
return out.Buff.String()
}

View file

@ -1,9 +1,12 @@
package jet
import "fmt"
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement {
func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(statement Statement) Statement {
newWithImpl := &withImpl{
ctes: cte,
recursive: recursive,
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
statementType: WithStatementType,
@ -23,7 +26,8 @@ func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statemen
type withImpl struct {
serializerStatementInterfaceImpl
ctes []CommonTableExpressionDefinition
recursive bool
ctes []*CommonTableExpression
primaryStatement SerializerStatement
}
@ -31,6 +35,10 @@ func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ..
out.NewLine()
out.WriteString("WITH")
if w.recursive {
out.WriteString("RECURSIVE")
}
for i, cte := range w.ctes {
if i > 0 {
out.WriteString(",")
@ -48,35 +56,55 @@ func (w withImpl) projections() ProjectionList {
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
selectTableImpl
NotMaterialized bool
Columns []ColumnExpression
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
return CommonTableExpression{
selectTableImpl: selectTableImpl{
selectStmt: nil,
alias: name,
},
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
selectTableImpl: NewSelectTable(nil, name),
Columns: columns,
}
for _, column := range cte.Columns {
column.setSubQuery(cte)
}
return cte
}
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.alias)
if statement == WithStatementType { // serialize CTE definition
out.WriteIdentifier(c.alias)
if len(c.Columns) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(c.Columns, out)
out.WriteByte(')')
}
out.WriteString("AS")
if c.NotMaterialized {
out.WriteString("NOT MATERIALIZED")
}
if c.Statement == nil {
panic(fmt.Sprintf("jet: '%s' CTE is not defined", c.alias))
}
c.Statement.serialize(statement, out, FallTrough(options)...)
} else { // serialize CTE in FROM clause
out.WriteIdentifier(c.alias)
}
}
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
}
// AllColumns returns list of all projections in the CTE
func (c CommonTableExpression) AllColumns() ProjectionList {
if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
}
// CommonTableExpressionDefinition contains implementation details of CTE
type CommonTableExpressionDefinition struct {
cte *CommonTableExpression
}
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
return c.selectTableImpl.AllColumns()
}

View file

@ -56,6 +56,12 @@ func PrintJson(v interface{}) {
fmt.Println(string(jsonText))
}
// ToJSON converts v into json string
func ToJSON(v interface{}) string {
jsonText, _ := json.MarshalIndent(v, "", "\t")
return string(jsonText)
}
// AssertJSON check if data json output is the same as expectedJSON
func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t")

View file

@ -6,6 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet"
type DeleteStatement interface {
Statement
USING(tables ...ReadableTable) DeleteStatement
WHERE(expression BoolExpression) DeleteStatement
ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement
LIMIT(limit int64) DeleteStatement
@ -15,6 +16,7 @@ type deleteStatementImpl struct {
jet.SerializerStatement
Delete jet.ClauseStatementBegin
Using jet.ClauseFrom
Where jet.ClauseWhere
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
@ -22,10 +24,15 @@ type deleteStatementImpl struct {
func newDeleteStatement(table Table) DeleteStatement {
newDelete := &deleteStatementImpl{}
newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, &newDelete.Delete,
&newDelete.Where, &newDelete.OrderBy, &newDelete.Limit)
newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete,
&newDelete.Delete,
&newDelete.Using,
&newDelete.Where,
&newDelete.OrderBy,
&newDelete.Limit)
newDelete.Delete.Name = "DELETE FROM"
newDelete.Using.Name = "USING"
newDelete.Delete.Tables = append(newDelete.Delete.Tables, table)
newDelete.Where.Mandatory = true
newDelete.Limit.Count = -1
@ -33,6 +40,11 @@ func newDeleteStatement(table Table) DeleteStatement {
return newDelete
}
func (d *deleteStatementImpl) USING(tables ...ReadableTable) DeleteStatement {
d.Using.Tables = readableTablesToSerializerList(tables)
return d
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.Where.Condition = expression
return d

View file

@ -87,7 +87,7 @@ var (
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
var Func = jet.Func
// NewEnumValue creates new named enum value

View file

@ -7,3 +7,6 @@ var NOT = jet.NOT
// BIT_NOT inverts every bit in integer expression result
var BIT_NOT = jet.BIT_NOT
// DISTINCT operator can be used to return distinct values of expr
var DISTINCT = jet.DISTINCT

View file

@ -58,7 +58,7 @@ type SelectStatement interface {
AsTable(alias string) SelectTable
}
//SELECT creates new SelectStatement with list of projections
// SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...))
}
@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
s.From.Tables = readableTablesToSerializerList(tables)
return s
}
@ -189,3 +186,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer {
return jet.FixedLiteral(offset)
}
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
var ret []jet.Serializer
for _, table := range tables {
ret = append(ret, table)
}
return ret
}

View file

@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE (NOT (EXISTS (
WHERE NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
WHERE table1.col_int = table2.col_int
)));
));
`)
}

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}

View file

@ -24,4 +24,11 @@ type OrderByClause = jet.OrderByClause
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging
// Deprecated: use SetQueryLogger instead.
var SetLogger = jet.SetLoggerFunc
// SetQueryLogger sets automatic query logging function.
var SetQueryLogger = jet.SetQueryLogger
// QueryInfo contains information about executed query
type QueryInfo = jet.QueryInfo

View file

@ -2,25 +2,65 @@ package mysql
import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
// CommonTableExpression defines set of interface methods for postgres CTEs
type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
internalCTE() *jet.CommonTableExpression
}
type commonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, false, toInternalCTE(cte)...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions
func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, true, toInternalCTE(cte)...)
}
// CTE creates new named commonTableExpression
func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
cte := &commonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
CommonTableExpression: jet.CTE(name, columns...),
}
cte.parent = &cte
cte.parent = cte
return cte
}
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
return &c.CommonTableExpression
}
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
var ret []*jet.CommonTableExpression
for _, cte := range ctes {
ret = append(ret, cte.internalCTE())
}
return ret
}

View file

@ -52,7 +52,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}

View file

@ -29,7 +29,7 @@ ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
)
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
SET col_bool = $1,
SET col_bool = $1::boolean,
col_int = $2
WHERE table2.col_float > $3`)
}

View file

@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet"
type DeleteStatement interface {
jet.SerializerStatement
USING(tables ...ReadableTable) DeleteStatement
WHERE(expression BoolExpression) DeleteStatement
RETURNING(projections ...jet.Projection) DeleteStatement
}
@ -15,22 +15,32 @@ type deleteStatementImpl struct {
jet.SerializerStatement
Delete jet.ClauseStatementBegin
Using jet.ClauseFrom
Where jet.ClauseWhere
Returning jet.ClauseReturning
}
func newDeleteStatement(table WritableTable) DeleteStatement {
newDelete := &deleteStatementImpl{}
newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, &newDelete.Delete,
&newDelete.Where, &newDelete.Returning)
newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete,
&newDelete.Delete,
&newDelete.Using,
&newDelete.Where,
&newDelete.Returning)
newDelete.Delete.Name = "DELETE FROM"
newDelete.Delete.Tables = append(newDelete.Delete.Tables, table)
newDelete.Using.Name = "USING"
newDelete.Where.Mandatory = true
return newDelete
}
func (d *deleteStatementImpl) USING(tables ...ReadableTable) DeleteStatement {
d.Using.Tables = readableTablesToSerializerList(tables)
return d
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.Where.Condition = expression
return d

View file

@ -33,7 +33,7 @@ func TestExists(t *testing.T) {
).EQ(Bool(true)),
`((EXISTS (
SELECT $1
)) = $2)`, int64(1), true)
)) = $2::boolean)`, int64(1), true)
assertProjectionSerialize(t, EXISTS(
SELECT(Int(1)),

View file

@ -100,7 +100,7 @@ var (
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
var Func = jet.Func
// NewEnumValue creates new named enum value

View file

@ -336,3 +336,25 @@ func explicitLiteralCast(expresion Expression) jet.Expression {
return expresion
}
// MODE computes the most frequent value of the aggregated argument
var MODE = jet.MODE
// PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of
// aggregated argument values. This will interpolate between adjacent input items if needed.
func PERCENTILE_CONT(fraction FloatExpression) *jet.OrderSetAggregateFunc {
return jet.PERCENTILE_CONT(castFloatLiteral(fraction))
}
// PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position
// in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type.
func PERCENTILE_DISC(fraction FloatExpression) *jet.OrderSetAggregateFunc {
return jet.PERCENTILE_DISC(castFloatLiteral(fraction))
}
func castFloatLiteral(fraction FloatExpression) FloatExpression {
if _, ok := fraction.(jet.LiteralExpression); ok {
return CAST(fraction).AS_DOUBLE() // to make postgres aware of the type
}
return fraction
}

View file

@ -0,0 +1,12 @@
package postgres
import "testing"
func TestROW(t *testing.T) {
assertSerialize(t, ROW(SELECT(Int(1))), `ROW((
SELECT $1
))`)
assertSerialize(t, ROW(Int(1), SELECT(Int(2)), Float(11.11)), `ROW($1, (
SELECT $2
), $3)`)
}

View file

@ -165,7 +165,7 @@ VALUES ('one', 'two'),
('1', '2'),
('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = TRUE,
SET col_bool = TRUE::boolean,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
@ -191,7 +191,7 @@ INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = FALSE,
SET col_bool = FALSE::boolean,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2

View file

@ -41,6 +41,8 @@ type IntervalExpression interface {
LT_EQ(rhs IntervalExpression) BoolExpression
GT(rhs IntervalExpression) BoolExpression
GT_EQ(rhs IntervalExpression) BoolExpression
BETWEEN(min, max IntervalExpression) BoolExpression
NOT_BETWEEN(min, max IntervalExpression) BoolExpression
ADD(rhs IntervalExpression) IntervalExpression
SUB(rhs IntervalExpression) IntervalExpression
@ -87,6 +89,14 @@ func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression {
return jet.GtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression {
return jet.NewBetweenOperatorExpression(i.parent, min, max, false)
}
func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression {
return jet.NewBetweenOperatorExpression(i.parent, min, max, true)
}
func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Add(i.parent, rhs))
}

View file

@ -67,7 +67,7 @@ func TestIntervalExpressionMethods(t *testing.T) {
assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')")
assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')")
assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)),
"((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false)
"((table1.col_interval = INTERVAL '11 MINUTE') = $1::boolean)", false)
assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)")
assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)")
assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)")

View file

@ -6,35 +6,53 @@ import (
"github.com/go-jet/jet/v2/internal/jet"
)
// Bool creates new bool literal expression
var Bool = jet.Bool
// Bool is boolean literal constructor
func Bool(value bool) BoolExpression {
return CAST(jet.Bool(value)).AS_BOOL()
}
// Int is constructor for 64 bit signed integer expressions literals.
var Int = jet.Int
// Int8 is constructor for 8 bit signed integer expressions literals.
var Int8 = jet.Int8
func Int8(value int8) IntegerExpression {
return CAST(jet.Int8(value)).AS_SMALLINT()
}
// Int16 is constructor for 16 bit signed integer expressions literals.
var Int16 = jet.Int16
func Int16(value int16) IntegerExpression {
return CAST(jet.Int16(value)).AS_SMALLINT()
}
// Int32 is constructor for 32 bit signed integer expressions literals.
var Int32 = jet.Int32
func Int32(value int32) IntegerExpression {
return CAST(jet.Int32(value)).AS_INTEGER()
}
// Int64 is constructor for 64 bit signed integer expressions literals.
var Int64 = jet.Int
func Int64(value int64) IntegerExpression {
return CAST(jet.Int(value)).AS_BIGINT()
}
// Uint8 is constructor for 8 bit unsigned integer expressions literals.
var Uint8 = jet.Uint8
func Uint8(value uint8) IntegerExpression {
return CAST(jet.Uint8(value)).AS_SMALLINT()
}
// Uint16 is constructor for 16 bit unsigned integer expressions literals.
var Uint16 = jet.Uint16
func Uint16(value uint16) IntegerExpression {
return CAST(jet.Uint16(value)).AS_INTEGER()
}
// Uint32 is constructor for 32 bit unsigned integer expressions literals.
var Uint32 = jet.Uint32
func Uint32(value uint32) IntegerExpression {
return CAST(jet.Uint32(value)).AS_BIGINT()
}
// Uint64 is constructor for 64 bit unsigned integer expressions literals.
var Uint64 = jet.Uint64
func Uint64(value uint64) IntegerExpression {
return CAST(jet.Uint64(value)).AS_BIGINT()
}
// Float creates new float literal expression
var Float = jet.Float

View file

@ -7,7 +7,7 @@ import (
)
func TestBool(t *testing.T) {
assertSerialize(t, Bool(false), `$1`, false)
assertSerialize(t, Bool(false), `$1::boolean`, false)
}
func TestInt(t *testing.T) {
@ -16,42 +16,42 @@ func TestInt(t *testing.T) {
func TestInt8(t *testing.T) {
val := int8(math.MinInt8)
assertSerialize(t, Int8(val), `$1`, val)
assertSerialize(t, Int8(val), `$1::smallint`, val)
}
func TestInt16(t *testing.T) {
val := int16(math.MinInt16)
assertSerialize(t, Int16(val), `$1`, val)
assertSerialize(t, Int16(val), `$1::smallint`, val)
}
func TestInt32(t *testing.T) {
val := int32(math.MinInt32)
assertSerialize(t, Int32(val), `$1`, val)
assertSerialize(t, Int32(val), `$1::integer`, val)
}
func TestInt64(t *testing.T) {
val := int64(math.MinInt64)
assertSerialize(t, Int64(val), `$1`, val)
assertSerialize(t, Int64(val), `$1::bigint`, val)
}
func TestUint8(t *testing.T) {
val := uint8(math.MaxUint8)
assertSerialize(t, Uint8(val), `$1`, val)
assertSerialize(t, Uint8(val), `$1::smallint`, val)
}
func TestUint16(t *testing.T) {
val := uint16(math.MaxUint16)
assertSerialize(t, Uint16(val), `$1`, val)
assertSerialize(t, Uint16(val), `$1::integer`, val)
}
func TestUint32(t *testing.T) {
val := uint32(math.MaxUint32)
assertSerialize(t, Uint32(val), `$1`, val)
assertSerialize(t, Uint32(val), `$1::bigint`, val)
}
func TestUint64(t *testing.T) {
val := uint64(math.MaxUint64)
assertSerialize(t, Uint64(val), `$1`, val)
assertSerialize(t, Uint64(val), `$1::bigint`, val)
}
func TestFloat(t *testing.T) {

View file

@ -7,3 +7,6 @@ var NOT = jet.NOT
// BIT_NOT inverts every bit in integer expression result
var BIT_NOT = jet.BIT_NOT
// DISTINCT operator can be used to return distinct values of expr
var DISTINCT = jet.DISTINCT

View file

@ -44,7 +44,7 @@ type SelectStatement interface {
jet.HasProjections
Expression
DISTINCT() SelectStatement
DISTINCT(on ...jet.ColumnExpression) SelectStatement
FROM(tables ...ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
@ -104,16 +104,14 @@ type selectStatementImpl struct {
For jet.ClauseFor
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
func (s *selectStatementImpl) DISTINCT(on ...jet.ColumnExpression) SelectStatement {
s.Select.Distinct = true
s.Select.DistinctOnColumns = on
return s
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
s.From.Tables = readableTablesToSerializerList(tables)
return s
}
@ -182,3 +180,11 @@ func toJetFrameOffset(offset int64) jet.Serializer {
}
return jet.FixedLiteral(offset)
}
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
var ret []jet.Serializer
for _, table := range tables {
ret = append(ret, table)
}
return ret
}

View file

@ -23,7 +23,7 @@ func TestSelectLiterals(t *testing.T) {
assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), `
SELECT $1,
$2,
$3
$3::boolean
FROM db.table1;
`, int64(1), 2.2, false)
}
@ -59,7 +59,7 @@ func TestSelectWhere(t *testing.T) {
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE $1;
WHERE $1::boolean;
`, true)
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.col_int AS "table1.col_int"
@ -80,7 +80,7 @@ func TestSelectHaving(t *testing.T) {
assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.col_int AS "table3.col_int"
FROM db.table3
HAVING table1.col_bool = $1;
HAVING table1.col_bool = $1::boolean;
`, true)
}

View file

@ -2,7 +2,7 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet"
// SelectTable is interface for MySQL sub-queries
// SelectTable is interface for postgres sub-queries
type SelectTable interface {
readableTable
jet.SelectTable
@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}

View file

@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging
// SetLogger sets automatic statement logging function
// Deprecated: use SetQueryLogger instead.
var SetLogger = jet.SetLoggerFunc
// SetQueryLogger sets automatic query logging function.
var SetQueryLogger = jet.SetQueryLogger
// QueryInfo contains information about executed query
type QueryInfo = jet.QueryInfo

View file

@ -11,8 +11,9 @@ type UpdateStatement interface {
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
FROM(tables ...ReadableTable) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...jet.Projection) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
}
type updateStatementImpl struct {
@ -21,6 +22,7 @@ type updateStatementImpl struct {
Update jet.ClauseUpdate
Set clauseSet
SetNew jet.SetClauseNew
From jet.ClauseFrom
Where jet.ClauseWhere
Returning jet.ClauseReturning
}
@ -31,6 +33,7 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme
&update.Update,
&update.Set,
&update.SetNew,
&update.From,
&update.Where,
&update.Returning)
@ -61,6 +64,11 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
return u
}
func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement {
u.From.Tables = readableTablesToSerializerList(tables)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u

View file

@ -2,25 +2,73 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
// CommonTableExpression defines set of interface methods for postgres CTEs
type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
internalCTE() *jet.CommonTableExpression
}
type commonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, false, toInternalCTE(cte)...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions
func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, true, toInternalCTE(cte)...)
}
// CTE creates new named commonTableExpression
func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
cte := &commonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
CommonTableExpression: jet.CTE(name, columns...),
}
cte.parent = &cte
cte.parent = cte
return cte
}
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
// AS_NOT_MATERIALIZED is used to define not materialized CTE query
func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.NotMaterialized = true
c.CommonTableExpression.Statement = statement
return c
}
func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
return &c.CommonTableExpression
}
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
var ret []*jet.CommonTableExpression
for _, cte := range ctes {
ret = append(ret, cte.internalCTE())
}
return ret
}

View file

@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set")
// using context `ctx` into destination `destPtr`.
// Destination can be either pointer to struct or pointer to slice of structs.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error {
func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
utils.MustBeInitializedPtr(db, "jet: db is nil")
utils.MustBeInitializedPtr(destPtr, "jet: destination is nil")
@ -26,11 +26,11 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
destinationPtrType := reflect.TypeOf(destPtr)
if destinationPtrType.Elem().Kind() == reflect.Slice {
_, err := queryToSlice(ctx, db, query, args, destPtr)
rowsProcessed, err := queryToSlice(ctx, db, query, args, destPtr)
if err != nil {
return fmt.Errorf("jet: %w", err)
return rowsProcessed, fmt.Errorf("jet: %w", err)
}
return nil
return rowsProcessed, nil
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
@ -38,16 +38,16 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
if err != nil {
return fmt.Errorf("jet: %w", err)
return rowsProcessed, fmt.Errorf("jet: %w", err)
}
if rowsProcessed == 0 {
return ErrNoRows
return 0, ErrNoRows
}
// edge case when row result set contains only NULLs.
if tempSliceValue.Len() == 0 {
return nil
return rowsProcessed, nil
}
structValue := reflect.ValueOf(destPtr).Elem()
@ -56,7 +56,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
if structValue.Type().AssignableTo(firstTempStruct.Type()) {
structValue.Set(tempSliceValue.Index(0).Elem())
}
return nil
return rowsProcessed, nil
} else {
panic("jet: destination has to be a pointer to slice or pointer to struct")
}
@ -87,7 +87,7 @@ func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
_, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil)
_, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil)
if err != nil {
return fmt.Errorf("failed to map a row, %w", err)
@ -136,35 +136,32 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
err = rows.Scan(scanContext.row...)
if err != nil {
return
return scanContext.rowNum, err
}
scanContext.rowNum++
_, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil)
if err != nil {
return
return scanContext.rowNum, err
}
}
err = rows.Close()
if err != nil {
return
return scanContext.rowNum, err
}
err = rows.Err()
if err != nil {
return
}
rowsProcessed = scanContext.rowNum
return
return scanContext.rowNum, rows.Err()
}
func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
func mapRowToSlice(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack,
slicePtrValue reflect.Value,
field *reflect.StructField) (updated bool, err error) {
sliceElemType := getSliceElemType(slicePtrValue)
@ -184,12 +181,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
if ok {
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true)
}
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field)
if err != nil {
return
@ -228,10 +225,25 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
return
}
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
func mapRowToStruct(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack, // to prevent circular dependency scan
structPtrValue reflect.Value,
parentField *reflect.StructField,
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
) (updated bool, err error) {
mapOnlySlices := len(onlySlices) > 0
structType := structPtrValue.Type().Elem()
if typesVisited.contains(&structType) {
return false, nil
}
typesVisited.push(&structType)
defer typesVisited.pop()
typeInf := scanContext.getTypeInfo(structType, parentField)
structValue := structPtrValue.Elem()
@ -248,7 +260,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
if fieldMap.complexType {
var changed bool
changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field)
if err != nil {
return
@ -295,7 +307,12 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
return
}
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) {
func mapRowToDestinationValue(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack,
dest reflect.Value,
structField *reflect.StructField) (updated bool, err error) {
var destPtrValue reflect.Value
@ -309,7 +326,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
}
}
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField)
if err != nil {
return
@ -322,16 +339,21 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return
}
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
func mapRowToDestinationPtr(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack,
destPtrValue reflect.Value,
structField *reflect.StructField) (updated bool, err error) {
utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
destValueKind := destPtrValue.Elem().Kind()
if destValueKind == reflect.Struct {
return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField)
} else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField)
} else {
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
}

View file

@ -132,7 +132,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.
return s.constructGroupKey(groupKeyInfo)
}
groupKeyInfo := s.getGroupKeyInfo(structType, structField)
groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack())
s.groupKeyInfoCache[mapKey] = groupKeyInfo
@ -144,7 +144,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
return fmt.Sprintf("|ROW:%d|", s.rowNum)
}
groupKeys := []string{}
var groupKeys []string
for _, index := range groupKeyInfo.indexes {
cellValue := s.rowElem(index)
@ -153,7 +153,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
groupKeys = append(groupKeys, subKey)
}
subTypesGroupKeys := []string{}
var subTypesGroupKeys []string
for _, subType := range groupKeyInfo.subTypes {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
}
@ -161,9 +161,20 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
}
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
func (s *scanContext) getGroupKeyInfo(
structType reflect.Type,
parentField *reflect.StructField,
typeVisited *typeStack) groupKeyInfo {
ret := groupKeyInfo{typeName: structType.Name()}
if typeVisited.contains(&structType) {
return ret
}
typeVisited.push(&structType)
defer typeVisited.pop()
typeName := getTypeName(structType, parentField)
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
@ -176,7 +187,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
continue
}
subType := s.getGroupKeyInfo(fieldType, &field)
subType := s.getGroupKeyInfo(fieldType, &field, typeVisited)
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType)

40
qrm/type_stack.go Normal file
View file

@ -0,0 +1,40 @@
package qrm
import "reflect"
type typeStack []*reflect.Type
func newTypeStack() *typeStack {
stack := make(typeStack, 0, 20)
return &stack
}
func (s *typeStack) isEmpty() bool {
return len(*s) == 0
}
func (s *typeStack) push(t *reflect.Type) {
*s = append(*s, t)
}
func (s *typeStack) pop() bool {
if s.isEmpty() {
return false
}
*s = (*s)[:len(*s)-1]
return true
}
func (s *typeStack) contains(t *reflect.Type) bool {
if s.isEmpty() {
return false
}
for _, typ := range *s {
if *typ == *t {
return true
}
}
return false
}

View file

@ -9,7 +9,7 @@ type DeleteStatement interface {
WHERE(expression BoolExpression) DeleteStatement
ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement
LIMIT(limit int64) DeleteStatement
RETURNING(projections ...jet.Projection) DeleteStatement
RETURNING(projections ...Projection) DeleteStatement
}
type deleteStatementImpl struct {

View file

@ -90,7 +90,7 @@ var (
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
var Func = jet.Func
// NewEnumValue creates new named enum value

View file

@ -24,7 +24,6 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert,
&newInsert.ValuesQuery,
&newInsert.OnDuplicateKey,
&newInsert.DefaultValues,
&newInsert.OnConflict,
&newInsert.Returning,
@ -40,12 +39,11 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
type insertStatementImpl struct {
jet.SerializerStatement
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
DefaultValues jet.ClauseOptional
OnConflict onConflictClause
Returning jet.ClauseReturning
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
DefaultValues jet.ClauseOptional
OnConflict onConflictClause
Returning jet.ClauseReturning
}
func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
@ -65,11 +63,6 @@ func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
return is
}
func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
is.OnDuplicateKey = assigments
return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
@ -85,29 +78,6 @@ func (is *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertSt
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
}
func (is *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
is.OnConflict = onConflictClause{
insertStatement: is,

View file

@ -45,7 +45,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}

View file

@ -7,3 +7,6 @@ var NOT = jet.NOT
// BIT_NOT inverts every bit in integer expression result
var BIT_NOT = jet.BIT_NOT
// DISTINCT operator can be used to return distinct values of expr
var DISTINCT = jet.DISTINCT

View file

@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
s.From.Tables = readableTablesToSerializerList(tables)
return s
}
@ -184,3 +181,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer {
return jet.FixedLiteral(offset)
}
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
var ret []jet.Serializer
for _, table := range tables {
ret = append(ret, table)
}
return ret
}

View file

@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE (NOT (EXISTS (
WHERE NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
WHERE table1.col_int = table2.col_int
)));
));
`)
}

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}

View file

@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging
// SetLogger sets automatic statement logging.
// Deprecated: use SetQueryLogger instead.
var SetLogger = jet.SetLoggerFunc
// SetQueryLogger sets automatic query logging function.
var SetQueryLogger = jet.SetQueryLogger
// QueryInfo contains information about executed query
type QueryInfo = jet.QueryInfo

View file

@ -9,14 +9,16 @@ type UpdateStatement interface {
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
FROM(tables ...ReadableTable) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...jet.Projection) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
}
type updateStatementImpl struct {
jet.SerializerStatement
Update jet.ClauseUpdate
From jet.ClauseFrom
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
@ -29,6 +31,7 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
&update.Update,
&update.Set,
&update.SetNew,
&update.From,
&update.Where,
&update.Returning)
@ -59,12 +62,17 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
return u
}
func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement {
u.From.Tables = readableTablesToSerializerList(tables)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
u.Returning.ProjectionList = projections
return u
}

View file

@ -2,25 +2,73 @@ package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
// CommonTableExpression defines set of interface methods for postgres CTEs
type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
internalCTE() *jet.CommonTableExpression
}
type commonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, false, toInternalCTE(cte)...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions
func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, true, toInternalCTE(cte)...)
}
// CTE creates new named commonTableExpression
func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
cte := &commonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
CommonTableExpression: jet.CTE(name, columns...),
}
cte.parent = &cte
cte.parent = cte
return cte
}
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
// AS_NOT_MATERIALIZED is used to define not materialized CTE query
func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.NotMaterialized = true
c.CommonTableExpression.Statement = statement
return c
}
func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
return &c.CommonTableExpression
}
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
var ret []*jet.CommonTableExpression
for _, cte := range ctes {
ret = append(ret, cte.internalCTE())
}
return ret
}

62
tests/Makefile Normal file
View file

@ -0,0 +1,62 @@
setup: checkout-testdata docker-compose-up
# checkout-testdata will checkout testdata from separate repository into git submodule.
checkout-testdata:
git submodule init
git submodule update
cd ./testdata && git fetch && git checkout master && git pull
# docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize
# database with testdata retrieved in previous step.
# On the first run this action might take couple of minutes. Docker temp data are stored in .docker directory.
docker-compose-up:
docker-compose up
init-all:
go run ./init/init.go -testsuite all
init-postgres:
go run ./init/init.go -testsuite postgres
init-mysql:
go run ./init/init.go -testsuite mysql
init-mariadb:
go run ./init/init.go -testsuite mariadb
init-sqlite:
go run ./init/init.go -testsuite sqlite
# jet-gen will call generator on each of the test databases to generate sql builder and model files need to run the tests.
jet-gen-all: install-jet-gen jet-gen-postgres jet-gen-mysql jet-gen-mariadb jet-gen-sqlite
install-jet-gen:
go build -o ${GOPATH}/bin/jet ../cmd/jet/
jet-gen-postgres:
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/
jet-gen-mysql:
jet -source=mysql -dsn="jet:jet@tcp(localhost:50902)/dvds" -path=./.gentestdata/mysql
jet -source=mysql -dsn="jet:jet@tcp(localhost:50902)/dvds2" -path=./.gentestdata/mysql
jet -source=mysql -dsn="jet:jet@tcp(localhost:50902)/test_sample" -path=./.gentestdata/mysql
jet-gen-mariadb:
jet -source=mariadb -dsn="jet:jet@tcp(localhost:50903)/dvds" -path=./.gentestdata/mysql
jet -source=mariadb -dsn="jet:jet@tcp(localhost:50903)/dvds2" -path=./.gentestdata/mysql
jet -source=mariadb -dsn="jet:jet@tcp(localhost:50903)/test_sample" -path=./.gentestdata/mysql
jet-gen-sqlite:
jet -source=sqlite -dsn="./testdata/init/sqlite/chinook.db" -schema=dvds -path=./.gentestdata/sqlite/chinook
jet -source=sqlite -dsn="./testdata/init/sqlite/sakila.db" -schema=dvds -path=./.gentestdata/sqlite/sakila
jet -source=sqlite -dsn="./testdata/init/sqlite/test_sample.db" -schema=dvds -path=./.gentestdata/sqlite/test_sample
# docker-compose-cleanup will stop and remove test containers, volumes, and images.
cleanup:
docker-compose down --volumes

29
tests/Readme.md Normal file
View file

@ -0,0 +1,29 @@
# Integration tests
This folder contains integration tests intended to test jet generator, statements and query result mapping with a running database.
## How to run tests?
Before we can run tests, we need to set up and initialize test databases.
To simplify the process there is a Makefile with a list of helper commands.
```shell
# We first need to checkout testdata from separate repository into git submodule,
# then download docker image for each of the databases listed in docker-compose.yaml file, and
# finally run and initialize databases with downloaded test data.
# Note that on the first run this command might take a couple of minutes.
make setup
# When databases are ready, we can generate sql builder and model types for each of the test databases
make jet-gen-all
```
Then we can run the tests the usual way:
```shell
go test -v ./...
```
To removes test containers, volumes, and images:
```shell
make cleanup
```

View file

@ -8,7 +8,7 @@ import (
// Postgres test database connection parameters
const (
PgHost = "localhost"
PgPort = 5432
PgPort = 50901
PgUser = "jet"
PgPassword = "jet"
PgDBName = "jetdb"
@ -19,14 +19,25 @@ var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbn
// MySQL test database connection parameters
const (
MySqLHost = "localhost"
MySQLPort = 3306
MySqLHost = "127.0.0.1"
MySQLPort = 50902
MySQLUser = "jet"
MySQLPassword = "jet"
MariaDBHost = "127.0.0.1"
MariaDBPort = 50903
MariaDBUser = "jet"
MariaDBPassword = "jet"
)
// MySQLConnectionString is MySQL driver connection string to test database
var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort)
// MySQLConnectionString is MySQL connection string for test database
func MySQLConnectionString(isMariaDB bool, dbName string) string {
if isMariaDB {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", MariaDBUser, MariaDBPassword, MariaDBHost, MariaDBPort, dbName)
}
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", MySQLUser, MySQLPassword, MySqLHost, MySQLPort, dbName)
}
// sqllite
var (

39
tests/docker-compose.yaml Normal file
View file

@ -0,0 +1,39 @@
version: '3'
services:
postgres:
image: postgres:14.1
restart: always
environment:
- POSTGRES_USER=jet
- POSTGRES_PASSWORD=jet
- POSTGRES_DB=jetdb
ports:
- '50901:5432'
volumes:
- ./testdata/init/postgres:/docker-entrypoint-initdb.d
mysql:
image: mysql:8.0.27
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_USER: jet
MYSQL_PASSWORD: jet
ports:
- '50902:3306'
volumes:
- ./testdata/init/mysql:/docker-entrypoint-initdb.d
mariadb:
image: mariadb:10.3.32
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_USER: jet
MYSQL_PASSWORD: jet
ports:
- '50903:3306'
volumes:
- ./testdata/init/mysql:/docker-entrypoint-initdb.d

6
tests/init/Readme.md Normal file
View file

@ -0,0 +1,6 @@
The `init` command can be used to initialize test databases on the local host machine, if needed.
Update [dbconfig](../dbconfig/dbconfig.go) with your local database parameters.
The recommended way to initialize test databases is by a docker container.
See tests [Readme.md](../Readme.md).

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"flag"
"fmt"
"github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"io/ioutil"
@ -11,7 +12,6 @@ import (
"os/exec"
"strings"
"github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/tests/dbconfig"
@ -39,7 +39,7 @@ func main() {
}
if testSuite == "mysql" || testSuite == "mariadb" {
initMySQLDB()
initMySQLDB(testSuite == "mariadb")
return
}
@ -48,8 +48,9 @@ func main() {
return
}
initMySQLDB()
initPostgresDB()
initMySQLDB(false)
initMySQLDB(true)
initSQLiteDB()
}
@ -62,7 +63,7 @@ func initSQLiteDB() {
throw.OnError(err)
}
func initMySQLDB() {
func initMySQLDB(isMariaDB bool) {
mySQLDBs := []string{
"dvds",
@ -71,8 +72,20 @@ func initMySQLDB() {
}
for _, dbName := range mySQLDBs {
cmdLine := fmt.Sprintf("mysql -h 127.0.0.1 -u %s -p%s %s < %s",
dbconfig.MySQLUser, dbconfig.MySQLPassword, dbName, "./testdata/init/mysql/"+dbName+".sql")
host := dbconfig.MySqLHost
port := dbconfig.MySQLPort
user := dbconfig.MySQLUser
pass := dbconfig.MySQLPassword
if isMariaDB {
host = dbconfig.MariaDBHost
port = dbconfig.MariaDBPort
user = dbconfig.MariaDBUser
pass = dbconfig.MariaDBPassword
}
cmdLine := fmt.Sprintf("mysql -h %s -P %d -u %s -p%s %s < %s", host, port, user, pass, dbName,
"./testdata/init/mysql/"+dbName+".sql")
fmt.Println(cmdLine)
@ -85,10 +98,10 @@ func initMySQLDB() {
throw.OnError(err)
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
Host: host,
Port: port,
User: user,
Password: pass,
DBName: dbName,
})
@ -99,7 +112,7 @@ func initMySQLDB() {
func initPostgresDB() {
db, err := sql.Open("postgres", dbconfig.PostgresConnectString)
if err != nil {
panic("Failed to connect to test db")
panic("Failed to connect to test db: " + err.Error())
}
defer func() {
err := db.Close()

View file

@ -31,10 +31,6 @@ func TestAllTypes(t *testing.T) {
require.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return
}
//testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, allTypesJson)
}
@ -49,10 +45,6 @@ func TestAllTypesViewSelect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return
}
testutils.AssertJSON(t, dest, allTypesJson)
}
@ -224,6 +216,8 @@ func TestFloatOperators(t *testing.T) {
AllTypes.Numeric.LT(Float(34.56)).AS("lt2"),
AllTypes.Numeric.GT(Float(124)).AS("gt1"),
AllTypes.Numeric.GT(Float(34.56)).AS("gt2"),
AllTypes.Numeric.BETWEEN(Float(1.34), AllTypes.Decimal).AS("between"),
AllTypes.Numeric.NOT_BETWEEN(AllTypes.Decimal.MUL(Float(3)), Float(100.12)).AS("not_between"),
TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"),
TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"),
@ -252,11 +246,9 @@ func TestFloatOperators(t *testing.T) {
TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"),
).LIMIT(2)
queryStr, _ := query.Sql()
// fmt.Println(query.Sql())
//fmt.Println(queryStr)
require.Equal(t, queryStr, strings.Replace(`
testutils.AssertStatementSql(t, query, strings.Replace(`
SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1",
(all_types.'decimal' = ?) AS "eq2",
(all_types.'real' = ?) AS "eq3",
@ -270,22 +262,24 @@ SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1",
(all_types.'numeric' < ?) AS "lt2",
(all_types.'numeric' > ?) AS "gt1",
(all_types.'numeric' > ?) AS "gt2",
TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1",
TRUNCATE((all_types.'decimal' + ?), ?) AS "add2",
TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1",
TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2",
TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1",
TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2",
TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1",
TRUNCATE((all_types.'decimal' / ?), ?) AS "div2",
TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1",
TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2",
(all_types.'numeric' BETWEEN ? AND all_types.'decimal') AS "between",
(all_types.'numeric' NOT BETWEEN (all_types.'decimal' * ?) AND ?) AS "not_between",
TRUNCATE(all_types.'decimal' + all_types.'decimal', ?) AS "add1",
TRUNCATE(all_types.'decimal' + ?, ?) AS "add2",
TRUNCATE(all_types.'decimal' - all_types.decimal_ptr, ?) AS "sub1",
TRUNCATE(all_types.'decimal' - ?, ?) AS "sub2",
TRUNCATE(all_types.'decimal' * all_types.decimal_ptr, ?) AS "mul1",
TRUNCATE(all_types.'decimal' * ?, ?) AS "mul2",
TRUNCATE(all_types.'decimal' / all_types.decimal_ptr, ?) AS "div1",
TRUNCATE(all_types.'decimal' / ?, ?) AS "div2",
TRUNCATE(all_types.'decimal' % all_types.decimal_ptr, ?) AS "mod1",
TRUNCATE(all_types.'decimal' % ?, ?) AS "mod2",
TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1",
TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2",
TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs",
TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power",
TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt",
TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt",
TRUNCATE(POWER(all_types.'decimal', ? / ?), ?) AS "cbrt",
CEIL(all_types.'real') AS "ceil",
FLOOR(all_types.'real') AS "floor",
ROUND(all_types.'decimal') AS "round1",
@ -316,61 +310,48 @@ func TestIntegerOperators(t *testing.T) {
AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"),
AllTypes.BigInt.EQ(Int(12)).AS("eq2"),
AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"),
AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"),
AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"),
AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"),
AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"),
AllTypes.BigInt.LT(Int(65)).AS("lt2"),
AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"),
AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"),
AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"),
AllTypes.BigInt.GT(Int(65)).AS("gt2"),
AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"),
AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"),
AllTypes.Integer.BETWEEN(Int(11), Int(200)).AS("between"),
AllTypes.Integer.NOT_BETWEEN(Int(66), Int(77)).AS("not_between"),
AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"),
AllTypes.BigInt.ADD(Int(11)).AS("add2"),
AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"),
AllTypes.BigInt.SUB(Int(11)).AS("sub2"),
AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"),
AllTypes.BigInt.MUL(Int(11)).AS("mul2"),
AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"),
AllTypes.BigInt.DIV(Int(11)).AS("div2"),
AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"),
AllTypes.BigInt.MOD(Int(11)).AS("mod2"),
AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"),
AllTypes.SmallInt.POW(Int(6)).AS("pow2"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"),
AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"),
AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"),
AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"),
AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"),
BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"),
BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"),
@ -379,9 +360,9 @@ func TestIntegerOperators(t *testing.T) {
CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"),
).LIMIT(2)
//fmt.Println(query.Sql())
// fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, `
testutils.AssertStatementSql(t, query, strings.ReplaceAll(`
SELECT all_types.big_int AS "all_types.big_int",
all_types.big_int_ptr AS "all_types.big_int_ptr",
all_types.small_int AS "all_types.small_int",
@ -402,6 +383,8 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.big_int > ?) AS "gt2",
(all_types.big_int >= all_types.big_int_ptr) AS "gte1",
(all_types.big_int >= ?) AS "gte2",
(all_types.''integer'' BETWEEN ? AND ?) AS "between",
(all_types.''integer'' NOT BETWEEN ? AND ?) AS "not_between",
(all_types.big_int + all_types.big_int) AS "add1",
(all_types.big_int + ?) AS "add2",
(all_types.big_int - all_types.big_int) AS "sub1",
@ -412,7 +395,7 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.big_int DIV ?) AS "div2",
(all_types.big_int % all_types.big_int) AS "mod1",
(all_types.big_int % ?) AS "mod2",
POW(all_types.small_int, (all_types.small_int DIV ?)) AS "pow1",
POW(all_types.small_int, all_types.small_int DIV ?) AS "pow1",
POW(all_types.small_int, ?) AS "pow2",
(all_types.small_int & all_types.small_int) AS "bit_and1",
(all_types.small_int & all_types.small_int) AS "bit_and2",
@ -428,10 +411,10 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.small_int >> ?) AS "bit shift right 2",
ABS(all_types.big_int) AS "abs",
SQRT(ABS(all_types.big_int)) AS "sqrt",
POWER(ABS(all_types.big_int), (? / ?)) AS "cbrt"
POWER(ABS(all_types.big_int), ? / ?) AS "cbrt"
FROM test_sample.all_types
LIMIT ?;
`)
`, "''", "`"))
var dest []struct {
common.AllTypesIntegerExpResult `alias:"."`
@ -461,6 +444,8 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.LT(String("Text")),
AllTypes.Text.LT_EQ(AllTypes.VarCharPtr),
AllTypes.Text.LT_EQ(String("Text")),
AllTypes.Text.BETWEEN(String("min"), String("max")),
AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr),
AllTypes.Text.CONCAT(String("text2")),
AllTypes.Text.CONCAT(Int(11)),
AllTypes.Text.LIKE(String("abc")),
@ -528,24 +513,21 @@ func TestTimeExpressions(t *testing.T) {
AllTypes.TimePtr.NOT_EQ(AllTypes.Time),
AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)),
AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)),
AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)),
AllTypes.Time.LT(AllTypes.Time),
AllTypes.Time.LT(Time(17, 46, 6)),
AllTypes.Time.LT_EQ(AllTypes.Time),
AllTypes.Time.LT_EQ(Time(16, 56, 56)),
AllTypes.Time.GT(AllTypes.Time),
AllTypes.Time.GT(Time(15, 16, 46)),
AllTypes.Time.GT_EQ(AllTypes.Time),
AllTypes.Time.GT_EQ(Time(14, 26, 36)),
AllTypes.Time.BETWEEN(Time(11, 0, 30, 100), AllTypes.TimePtr),
AllTypes.Time.NOT_BETWEEN(AllTypes.TimePtr, AllTypes.Time.ADD(INTERVAL(2, HOUR))),
AllTypes.Time.ADD(INTERVAL(10, MINUTE)),
AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)),
@ -583,6 +565,8 @@ SELECT CAST('20:34:58' AS TIME),
all_types.time > CAST('15:16:46' AS TIME),
all_types.time >= all_types.time,
all_types.time >= CAST('14:26:36' AS TIME),
all_types.time BETWEEN CAST('11:00:30.0000001' AS TIME) AND all_types.time_ptr,
all_types.time NOT BETWEEN all_types.time_ptr AND (all_types.time + INTERVAL 2 HOUR),
all_types.time + INTERVAL 10 MINUTE,
all_types.time + INTERVAL all_types.''integer'' MINUTE,
all_types.time + INTERVAL 3 HOUR,
@ -594,7 +578,7 @@ SELECT CAST('20:34:58' AS TIME),
CURRENT_TIME(3)
FROM test_sample.all_types;
`, "''", "`", -1), "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06",
"19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36")
"19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36", "11:00:30.0000001")
dest := []struct{}{}
err := query.Query(db, &dest)
@ -608,27 +592,23 @@ func TestDateExpressions(t *testing.T) {
AllTypes.Date.EQ(AllTypes.Date),
AllTypes.Date.EQ(Date(2019, 6, 6)),
AllTypes.DatePtr.NOT_EQ(AllTypes.Date),
AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)),
AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date),
AllTypes.Date.IS_DISTINCT_FROM(Date(2019, 2, 6)),
AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date),
AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)),
AllTypes.Date.LT(AllTypes.Date),
AllTypes.Date.LT(Date(2019, 4, 6)),
AllTypes.Date.LT_EQ(AllTypes.Date),
AllTypes.Date.LT_EQ(Date(2019, 5, 5)),
AllTypes.Date.GT(AllTypes.Date),
AllTypes.Date.GT(Date(2019, 1, 4)),
AllTypes.Date.GT_EQ(AllTypes.Date),
AllTypes.Date.GT_EQ(Date(2019, 2, 3)),
AllTypes.Date.BETWEEN(Date(2000, 2, 2), AllTypes.DatePtr),
AllTypes.Date.NOT_BETWEEN(AllTypes.DatePtr, Date(2000, 2, 2)),
AllTypes.Date.ADD(INTERVAL("10:20.000100", MINUTE_MICROSECOND)),
AllTypes.Date.ADD(INTERVALe(AllTypes.BigInt, MINUTE)),
@ -661,6 +641,8 @@ SELECT CAST('2009-11-17' AS DATE),
all_types.date > CAST('2019-01-04' AS DATE),
all_types.date >= all_types.date,
all_types.date >= CAST('2019-02-03' AS DATE),
all_types.date BETWEEN CAST('2000-02-02' AS DATE) AND all_types.date_ptr,
all_types.date NOT BETWEEN all_types.date_ptr AND CAST('2000-02-02' AS DATE),
all_types.date + INTERVAL '10:20.000100' MINUTE_MICROSECOND,
all_types.date + INTERVAL all_types.big_int MINUTE,
all_types.date + INTERVAL 15 HOUR,
@ -684,27 +666,23 @@ func TestDateTimeExpressions(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.DateTime.EQ(AllTypes.DateTime),
AllTypes.DateTime.EQ(dateTime),
AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime),
AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)),
AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime),
AllTypes.DateTime.IS_DISTINCT_FROM(dateTime),
AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime),
AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime),
AllTypes.DateTime.LT(AllTypes.DateTime),
AllTypes.DateTime.LT(dateTime),
AllTypes.DateTime.LT_EQ(AllTypes.DateTime),
AllTypes.DateTime.LT_EQ(dateTime),
AllTypes.DateTime.GT(AllTypes.DateTime),
AllTypes.DateTime.GT(dateTime),
AllTypes.DateTime.GT_EQ(AllTypes.DateTime),
AllTypes.DateTime.GT_EQ(dateTime),
AllTypes.DateTime.BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr),
AllTypes.DateTime.NOT_BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr),
AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)),
@ -718,7 +696,7 @@ func TestDateTimeExpressions(t *testing.T) {
NOW(1),
)
//Println(query.DebugSql())
//fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, `
SELECT all_types.date_time = all_types.date_time,
@ -737,6 +715,8 @@ SELECT all_types.date_time = all_types.date_time,
all_types.date_time > CAST('2019-06-06 10:02:46' AS DATETIME),
all_types.date_time >= all_types.date_time,
all_types.date_time >= CAST('2019-06-06 10:02:46' AS DATETIME),
all_types.date_time BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr,
all_types.date_time NOT BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr,
all_types.date_time + INTERVAL '05:10:20.000100' HOUR_MICROSECOND,
all_types.date_time + INTERVAL all_types.big_int HOUR,
all_types.date_time + INTERVAL 2 HOUR,
@ -761,27 +741,23 @@ func TestTimestampExpressions(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.Timestamp.EQ(AllTypes.Timestamp),
AllTypes.Timestamp.EQ(timestamp),
AllTypes.TimestampPtr.NOT_EQ(AllTypes.Timestamp),
AllTypes.TimestampPtr.NOT_EQ(Timestamp(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)),
AllTypes.Timestamp.IS_DISTINCT_FROM(AllTypes.Timestamp),
AllTypes.Timestamp.IS_DISTINCT_FROM(timestamp),
AllTypes.Timestamp.IS_NOT_DISTINCT_FROM(AllTypes.Timestamp),
AllTypes.Timestamp.IS_NOT_DISTINCT_FROM(timestamp),
AllTypes.Timestamp.LT(AllTypes.Timestamp),
AllTypes.Timestamp.LT(timestamp),
AllTypes.Timestamp.LT_EQ(AllTypes.Timestamp),
AllTypes.Timestamp.LT_EQ(timestamp),
AllTypes.Timestamp.GT(AllTypes.Timestamp),
AllTypes.Timestamp.GT(timestamp),
AllTypes.Timestamp.GT_EQ(AllTypes.Timestamp),
AllTypes.Timestamp.GT_EQ(timestamp),
AllTypes.Timestamp.BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr),
AllTypes.Timestamp.NOT_BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr),
AllTypes.Timestamp.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
AllTypes.Timestamp.ADD(INTERVALe(AllTypes.BigInt, HOUR)),
@ -814,6 +790,8 @@ SELECT all_types.timestamp = all_types.timestamp,
all_types.timestamp > TIMESTAMP('2019-06-06 10:02:46'),
all_types.timestamp >= all_types.timestamp,
all_types.timestamp >= TIMESTAMP('2019-06-06 10:02:46'),
all_types.timestamp BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr,
all_types.timestamp NOT BETWEEN all_types.date_time_ptr AND all_types.timestamp_ptr,
all_types.timestamp + INTERVAL '05:10:20.000100' HOUR_MICROSECOND,
all_types.timestamp + INTERVAL all_types.big_int HOUR,
all_types.timestamp + INTERVAL 2 HOUR,

View file

@ -4,6 +4,7 @@ import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/require"
@ -91,3 +92,29 @@ func initForDeleteTest(t *testing.T) {
testutils.AssertExec(t, stmt, db, 2)
}
func TestDeleteWithUsing(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
stmt := table.Rental.DELETE().
USING(
table.Rental.
INNER_JOIN(table.Staff, table.Rental.StaffID.EQ(table.Staff.StaffID)),
table.Actor,
).
WHERE(
table.Staff.StaffID.NOT_EQ(Int(2)).
AND(table.Rental.RentalID.LT(Int(100))),
)
testutils.AssertStatementSql(t, stmt, `
DELETE FROM dvds.rental
USING dvds.rental
INNER JOIN dvds.staff ON (rental.staff_id = staff.staff_id),
dvds.actor
WHERE (staff.staff_id != ?) AND (rental.rental_id < ?);
`)
testutils.AssertExec(t, stmt, tx)
}

View file

@ -25,18 +25,30 @@ var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view")
var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum")
var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go")
var dbConnection = mysql2.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
DBName: "dvds",
func dbConnection(dbName string) mysql2.DBConnection {
if sourceIsMariaDB() {
return mysql2.DBConnection{
Host: dbconfig.MariaDBHost,
Port: dbconfig.MariaDBPort,
User: dbconfig.MariaDBUser,
Password: dbconfig.MariaDBPassword,
DBName: dbName,
}
}
return mysql2.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
DBName: dbName,
}
}
func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path")
@ -54,7 +66,7 @@ func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) {
func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -75,7 +87,7 @@ func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) {
func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -98,7 +110,7 @@ func TestGeneratorTemplate_Model_ChangePath(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -116,7 +128,7 @@ func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -137,7 +149,7 @@ func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) {
func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -175,7 +187,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) {
func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -203,7 +215,7 @@ func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) {
func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -236,7 +248,7 @@ func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) {
func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -277,7 +289,7 @@ func TestGeneratorTemplate_Model_AddTags(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -318,7 +330,7 @@ func TestGeneratorTemplate_Model_AddTags(t *testing.T) {
func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
@ -361,7 +373,7 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) {
func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
dbConnection("dvds"),
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).

View file

@ -1,10 +1,10 @@
package mysql
import (
"fmt"
"io/ioutil"
"os"
"os/exec"
"strconv"
"testing"
"github.com/go-jet/jet/v2/generator/mysql"
@ -19,13 +19,7 @@ const genTestDir3 = "./.gentestdata3/mysql"
func TestGenerator(t *testing.T) {
for i := 0; i < 3; i++ {
err := mysql.Generate(genTestDir3, mysql.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
DBName: "dvds",
})
err := mysql.Generate(genTestDir3, dbConnection("dvds"))
require.NoError(t, err)
@ -33,17 +27,11 @@ func TestGenerator(t *testing.T) {
}
for i := 0; i < 3; i++ {
dsn := fmt.Sprintf("%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s",
dbconfig.MySQLUser,
dbconfig.MySQLPassword,
dbconfig.MySqLHost,
dbconfig.MySQLPort,
"dvds",
)
dsn := dbconfig.MySQLConnectionString(sourceIsMariaDB(), "dvds")
err := mysql.GenerateDSN(dsn, genTestDir3)
require.NoError(t, err)
assertGeneratedFiles(t)
}
@ -55,8 +43,27 @@ func TestCmdGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir3)
require.NoError(t, err)
cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306",
"-user=jet", "-password=jet", "-path="+genTestDir3)
var cmd *exec.Cmd
if sourceIsMariaDB() {
cmd = exec.Command("jet",
"-source=MariaDB",
"-dbname=dvds",
"-host="+dbconfig.MariaDBHost,
"-port="+strconv.Itoa(dbconfig.MariaDBPort),
"-user="+dbconfig.MariaDBUser,
"-password="+dbconfig.MariaDBPassword,
"-path="+genTestDir3)
} else {
cmd = exec.Command("jet",
"-source=MySQL",
"-dbname=dvds",
"-host="+dbconfig.MySqLHost,
"-port="+strconv.Itoa(dbconfig.MySQLPort),
"-user="+dbconfig.MySQLUser,
"-password="+dbconfig.MySQLPassword,
"-path="+genTestDir3)
}
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
@ -70,13 +77,7 @@ func TestCmdGenerator(t *testing.T) {
require.NoError(t, err)
// check that generation via DSN works
dsn := fmt.Sprintf("mysql://%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s",
dbconfig.MySQLUser,
dbconfig.MySQLPassword,
dbconfig.MySqLHost,
dbconfig.MySQLPort,
"dvds",
)
dsn := "mysql://" + dbconfig.MySQLConnectionString(sourceIsMariaDB(), "dvds")
cmd = exec.Command("jet", "-dsn="+dsn, "-path="+genTestDir3)
cmd.Stderr = os.Stderr
@ -84,9 +85,48 @@ func TestCmdGenerator(t *testing.T) {
err = cmd.Run()
require.NoError(t, err)
}
err = os.RemoveAll(genTestDirRoot)
func TestIgnoreTablesViewsEnums(t *testing.T) {
cmd := exec.Command("jet",
"-source=MySQL",
"-dbname=dvds",
"-host="+dbconfig.MySqLHost,
"-port="+strconv.Itoa(dbconfig.MySQLPort),
"-user="+dbconfig.MySQLUser,
"-password="+dbconfig.MySQLPassword,
"-ignore-tables=actor,ADDRESS,Category, city ,country,staff,store,rental",
"-ignore-views=actor_info,CUSTomER_LIST, film_list",
"-ignore-enums=film_list_rating,film_rating",
"-path="+genTestDir3)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
err := cmd.Run()
require.NoError(t, err)
tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "customer.go", "film.go", "film_actor.go",
"film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go")
viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "sales_by_store.go", "staff_list.go")
enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, enumFiles, "nicer_but_slower_film_list_rating.go")
modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles,
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "nicer_but_slower_film_list_rating.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"sales_by_store.go", "staff_list.go")
}
func assertGeneratedFiles(t *testing.T) {

View file

@ -8,6 +8,7 @@ import (
"github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/stretchr/testify/require"
"math/rand"
"runtime"
"time"
_ "github.com/go-sql-driver/mysql"
@ -36,7 +37,7 @@ func TestMain(m *testing.M) {
defer profile.Start().Stop()
var err error
db, err = sql.Open("mysql", dbconfig.MySQLConnectionString)
db, err = sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), ""))
if err != nil {
panic("Failed to connect to test db" + err.Error())
}
@ -51,11 +52,21 @@ var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
var queryInfo jetmysql.QueryInfo
var callerFile string
var callerLine int
var callerFunction string
func init() {
jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql()
})
jetmysql.SetQueryLogger(func(ctx context.Context, info jetmysql.QueryInfo) {
queryInfo = info
callerFile, callerLine, callerFunction = info.Caller()
})
}
func requireLogged(t *testing.T, statement postgres.Statement) {
@ -65,8 +76,29 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
require.Equal(t, loggedDebugSQL, statement.DebugSql())
}
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) {
query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql()
require.Equal(t, query, queryLogged)
require.Equal(t, args, argsLogged)
require.Equal(t, queryInfo.RowsProcessed, rowsProcessed)
pc, file, _, _ := runtime.Caller(1)
funcDetails := runtime.FuncForPC(pc)
require.Equal(t, file, callerFile)
require.NotEmpty(t, callerLine)
require.Equal(t, funcDetails.Name(), callerFunction)
}
func skipForMariaDB(t *testing.T) {
if sourceIsMariaDB() {
t.SkipNow()
}
}
func beginTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -38,6 +38,7 @@ WHERE actor.actor_id = ?;
testutils.AssertDeepEqual(t, actor, actor2)
requireLogged(t, query)
requireQueryLogged(t, query, 1)
}
var actor2 = model.Actor{
@ -60,9 +61,9 @@ SELECT actor.actor_id AS "actor.actor_id",
FROM dvds.actor
ORDER BY actor.actor_id;
`)
dest := []model.Actor{}
var dest []model.Actor
err := query.Query(db, &dest)
err := query.QueryContext(context.Background(), db, &dest)
require.NoError(t, err)
@ -73,6 +74,7 @@ ORDER BY actor.actor_id;
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json")
requireLogged(t, query)
requireQueryLogged(t, query, 200)
}
func TestSelectGroupByHaving(t *testing.T) {
@ -153,6 +155,68 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC;
requireLogged(t, query)
}
func TestAggregateFunctionDistinct(t *testing.T) {
stmt := SELECT(
Payment.CustomerID,
COUNT(DISTINCT(Payment.Amount)).AS("distinct.count"),
SUM(DISTINCT(Payment.Amount)).AS("distinct.sum"),
AVG(DISTINCT(Payment.Amount)).AS("distinct.avg"),
MIN(DISTINCT(Payment.PaymentDate)).AS("distinct.first_payment_date"),
MAX(DISTINCT(Payment.PaymentDate)).AS("distinct.last_payment_date"),
).FROM(
Payment,
).WHERE(
Payment.CustomerID.EQ(Int(1)),
).GROUP_BY(
Payment.CustomerID,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT payment.customer_id AS "payment.customer_id",
COUNT(DISTINCT payment.amount) AS "distinct.count",
SUM(DISTINCT payment.amount) AS "distinct.sum",
AVG(DISTINCT payment.amount) AS "distinct.avg",
MIN(DISTINCT payment.payment_date) AS "distinct.first_payment_date",
MAX(DISTINCT payment.payment_date) AS "distinct.last_payment_date"
FROM dvds.payment
WHERE payment.customer_id = 1
GROUP BY payment.customer_id;
`)
type Distinct struct {
model.Payment
Count int64
Sum float64
Avg float64
FirstPaymentDate time.Time
LastPaymentDate time.Time
}
var dest Distinct
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
{
"PaymentID": 0,
"CustomerID": 1,
"StaffID": 0,
"RentalID": null,
"Amount": 0,
"PaymentDate": "0001-01-01T00:00:00Z",
"LastUpdate": "0001-01-01T00:00:00Z",
"Count": 8,
"Sum": 38.92,
"Avg": 4.865,
"FirstPaymentDate": "2005-05-25T11:30:37Z",
"LastPaymentDate": "2005-08-22T20:03:46Z"
}
`)
}
func TestSubQuery(t *testing.T) {
rRatingFilms := Film.
@ -389,8 +453,6 @@ LIMIT ?;
).
LIMIT(1000)
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(1000))
var dest []struct {
@ -414,12 +476,7 @@ LIMIT ?;
err := query.Query(db, &dest)
require.NoError(t, err)
//require.Equal(t, len(dest), 1)
//require.Equal(t, len(dest[0].Films), 10)
//require.Equal(t, len(dest[0].Films[0].Actors), 10)
//testutils.SaveJsonFile(dest, "./mysql/testdata/lang_film_actor_inventory_rental.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/lang_film_actor_inventory_rental.json")
}
}

View file

@ -261,15 +261,22 @@ func TestUpdateExecContext(t *testing.T) {
}
func TestUpdateWithJoin(t *testing.T) {
query := table.Staff.
INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
tx := beginTx(t)
defer tx.Rollback()
statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
UPDATE(table.Staff.LastName).
SET(String("New name")).
SET(String("New staff name")).
WHERE(table.Staff.StaffID.EQ(Int(1)))
//fmt.Println(query.DebugSql())
testutils.AssertStatementSql(t, statement, `
UPDATE dvds.staff
INNER JOIN dvds.address ON (address.address_id = staff.address_id)
SET last_name = ?
WHERE staff.staff_id = ?;
`, "New staff name", int64(1))
_, err := query.Exec(db)
_, err := statement.Exec(tx)
require.NoError(t, err)
}

View file

@ -149,7 +149,26 @@ func TestWITH_And_DELETE(t *testing.T) {
),
)
//fmt.Println(stmt.DebugSql())
// fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
WITH payments_to_delete AS (
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update"
FROM dvds.payment
WHERE payment.amount < 0.5
)
DELETE FROM dvds.payment
WHERE payment.payment_id IN (
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
);
`, "''", "`"))
tx, err := db.Begin()
require.NoError(t, err)
@ -157,3 +176,119 @@ func TestWITH_And_DELETE(t *testing.T) {
testutils.AssertExec(t, stmt, tx, 24)
}
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
// CTE columns are listed as part of CTE definition
n1 := IntegerColumn("n1")
fibN1 := IntegerColumn("fibN1")
nextFibN1 := IntegerColumn("nextFibN1")
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
// CTE columns are columns from non-recursive select
fibonacci2 := CTE("fibonacci2")
n2 := IntegerColumn("n2").From(fibonacci2)
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
stmt := WITH_RECURSIVE(
fibonacci1.AS(
SELECT(
Int32(1), Int32(0), Int32(1),
).UNION_ALL(
SELECT(
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
).FROM(
fibonacci1,
).WHERE(
n1.LT(Int(20)),
),
),
),
fibonacci2.AS(
SELECT(
Int32(1).AS(n2.Name()),
Int32(0).AS(fibN2.Name()),
Int32(1).AS(nextFibN2.Name()),
).UNION_ALL(
SELECT(
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
).FROM(
fibonacci2,
).WHERE(
n2.LT(Int(20)),
),
),
),
)(
SELECT(
fibonacci1.AllColumns(),
fibonacci2.AllColumns(),
).FROM(
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
).WHERE(
n1.EQ(Int(20)),
),
)
// fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS (
(
SELECT ?,
?,
?
)
UNION ALL
(
SELECT fibonacci1.n1 + ?,
fibonacci1.''nextFibN1'' AS "nextFibN1",
fibonacci1.''fibN1'' + fibonacci1.''nextFibN1''
FROM fibonacci1
WHERE fibonacci1.n1 < ?
)
),fibonacci2 AS (
(
SELECT ? AS "n2",
? AS "fibN2",
? AS "nextFibN2"
)
UNION ALL
(
SELECT fibonacci2.n2 + ?,
fibonacci2.''nextFibN2'' AS "nextFibN2",
fibonacci2.''fibN2'' + fibonacci2.''nextFibN2''
FROM fibonacci2
WHERE fibonacci2.n2 < ?
)
)
SELECT fibonacci1.n1 AS "n1",
fibonacci1.''fibN1'' AS "fibN1",
fibonacci1.''nextFibN1'' AS "nextFibN1",
fibonacci2.n2 AS "n2",
fibonacci2.''fibN2'' AS "fibN2",
fibonacci2.''nextFibN2'' AS "nextFibN2"
FROM fibonacci1
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
WHERE fibonacci1.n1 = ?;
`, "''", "`"))
var dest struct {
N1 int
FibN1 int
NextFibN1 int
N2 int
FibN2 int
NextFibN2 int
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.N1, 20)
require.Equal(t, dest.FibN1, 4181)
require.Equal(t, dest.NextFibN1, 6765)
require.Equal(t, dest.N2, 20)
require.Equal(t, dest.FibN2, 4181)
require.Equal(t, dest.NextFibN2, 6765)
}

View file

@ -225,7 +225,7 @@ func TestExpressionOperators(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.Integer.IS_NULL().AS("result.is_null"),
AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"),
AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"),
AllTypes.SmallIntPtr.IN(Int8(11), Int8(22)).AS("result.in"),
AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"),
Raw("CURRENT_USER").AS("result.raw"),
@ -233,14 +233,16 @@ func TestExpressionOperators(t *testing.T) {
Raw("#1 + all_types.integer + #2 + #1 + #3 + #4",
RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"),
AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"),
AllTypes.SmallIntPtr.NOT_IN(Int(11), Int16(22), NULL).AS("result.not_in"),
AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"),
).LIMIT(2)
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, `
SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN ($1, $2)) AS "result.in",
(all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in",
(all_types.small_int_ptr IN (
SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types
@ -248,14 +250,14 @@ SELECT all_types.integer IS NULL AS "result.is_null",
(CURRENT_USER) AS "result.raw",
($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg",
($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (
SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types
)) AS "result.not_in_select"
FROM test_sample.all_types
LIMIT $11;
`, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2))
`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2))
var dest []struct {
common.ExpressionTestResult `alias:"result.*"`
@ -359,6 +361,8 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.LT(String("Text")),
AllTypes.Text.LT_EQ(AllTypes.VarChar),
AllTypes.Text.LT_EQ(String("Text")),
AllTypes.Text.BETWEEN(String("min"), String("max")),
AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr),
AllTypes.Text.CONCAT(String("text2")),
AllTypes.Text.CONCAT(Int(11)),
AllTypes.Text.LIKE(String("abc")),
@ -450,13 +454,13 @@ func TestBoolOperators(t *testing.T) {
testutils.AssertStatementSql(t, query, `
SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1",
(all_types.boolean = $1) AS "EQ2",
(all_types.boolean = $1::boolean) AS "EQ2",
(all_types.boolean != all_types.boolean_ptr) AS "NEq1",
(all_types.boolean != $2) AS "NEq2",
(all_types.boolean != $2::boolean) AS "NEq2",
(all_types.boolean IS DISTINCT FROM all_types.boolean_ptr) AS "distinct1",
(all_types.boolean IS DISTINCT FROM $3) AS "distinct2",
(all_types.boolean IS DISTINCT FROM $3::boolean) AS "distinct2",
(all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1",
(all_types.boolean IS NOT DISTINCT FROM $4) AS "NOTDISTINCT2",
(all_types.boolean IS NOT DISTINCT FROM $4::boolean) AS "NOTDISTINCT2",
all_types.boolean IS TRUE AS "ISTRUE",
all_types.boolean IS NOT TRUE AS "isnottrue",
all_types.boolean IS FALSE AS "is_False",
@ -511,24 +515,26 @@ func TestFloatOperators(t *testing.T) {
AllTypes.Numeric.LT(Float(34.56)).AS("lt2"),
AllTypes.Numeric.GT(Float(124)).AS("gt1"),
AllTypes.Numeric.GT(Float(34.56)).AS("gt2"),
AllTypes.Numeric.BETWEEN(Float(1.34), AllTypes.Decimal).AS("between"),
AllTypes.Numeric.NOT_BETWEEN(AllTypes.Decimal.MUL(Float(3)), Float(100.12)).AS("not_between"),
TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Int(2)).AS("add1"),
TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int(2)).AS("add2"),
TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Int(2)).AS("sub1"),
TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int(2)).AS("sub2"),
TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int(2)).AS("mul1"),
TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int(2)).AS("mul2"),
TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int(2)).AS("div1"),
TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int(2)).AS("div2"),
TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int(2)).AS("mod1"),
TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int(2)).AS("mod2"),
TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int(2)).AS("pow1"),
TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int(2)).AS("pow2"),
TRUNC(AllTypes.Decimal.ADD(AllTypes.Decimal), Uint8(2)).AS("add1"),
TRUNC(AllTypes.Decimal.ADD(Float(11.22)), Int8(2)).AS("add2"),
TRUNC(AllTypes.Decimal.SUB(AllTypes.DecimalPtr), Uint16(2)).AS("sub1"),
TRUNC(AllTypes.Decimal.SUB(Float(11.22)), Int16(2)).AS("sub2"),
TRUNC(AllTypes.Decimal.MUL(AllTypes.DecimalPtr), Int16(2)).AS("mul1"),
TRUNC(AllTypes.Decimal.MUL(Float(11.22)), Int32(2)).AS("mul2"),
TRUNC(AllTypes.Decimal.DIV(AllTypes.DecimalPtr), Int32(2)).AS("div1"),
TRUNC(AllTypes.Decimal.DIV(Float(11.22)), Int8(2)).AS("div2"),
TRUNC(AllTypes.Decimal.MOD(AllTypes.DecimalPtr), Int8(2)).AS("mod1"),
TRUNC(AllTypes.Decimal.MOD(Float(11.22)), Int8(2)).AS("mod2"),
TRUNC(AllTypes.Decimal.POW(AllTypes.DecimalPtr), Int8(2)).AS("pow1"),
TRUNC(AllTypes.Decimal.POW(Float(2.1)), Int8(2)).AS("pow2"),
TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"),
TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"),
TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"),
TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"),
TRUNC(ABSf(AllTypes.Decimal), Int8(2)).AS("abs"),
TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int8(2)).AS("power"),
TRUNC(SQRT(AllTypes.Decimal), Int16(2)).AS("sqrt"),
TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int8(2)).AS("cbrt"),
CEIL(AllTypes.Real).AS("ceil"),
FLOOR(AllTypes.Real).AS("floor"),
@ -536,12 +542,12 @@ func TestFloatOperators(t *testing.T) {
ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round2"),
SIGN(AllTypes.Real).AS("sign"),
TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"),
TRUNC(AllTypes.Decimal, Int32(1)).AS("trunc"),
).LIMIT(2)
queryStr, _ := query.Sql()
//fmt.Println(query.Sql())
require.Equal(t, queryStr, `
testutils.AssertStatementSql(t, query, `
SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.decimal = $1) AS "eq2",
(all_types.real = $2) AS "eq3",
@ -555,30 +561,32 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.numeric < $8) AS "lt2",
(all_types.numeric > $9) AS "gt1",
(all_types.numeric > $10) AS "gt2",
TRUNC((all_types.decimal + all_types.decimal), $11) AS "add1",
TRUNC((all_types.decimal + $12), $13) AS "add2",
TRUNC((all_types.decimal - all_types.decimal_ptr), $14) AS "sub1",
TRUNC((all_types.decimal - $15), $16) AS "sub2",
TRUNC((all_types.decimal * all_types.decimal_ptr), $17) AS "mul1",
TRUNC((all_types.decimal * $18), $19) AS "mul2",
TRUNC((all_types.decimal / all_types.decimal_ptr), $20) AS "div1",
TRUNC((all_types.decimal / $21), $22) AS "div2",
TRUNC((all_types.decimal % all_types.decimal_ptr), $23) AS "mod1",
TRUNC((all_types.decimal % $24), $25) AS "mod2",
TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $26) AS "pow1",
TRUNC(POW(all_types.decimal, $27), $28) AS "pow2",
TRUNC(ABS(all_types.decimal), $29) AS "abs",
TRUNC(POWER(all_types.decimal, $30), $31) AS "power",
TRUNC(SQRT(all_types.decimal), $32) AS "sqrt",
TRUNC(CBRT(all_types.decimal)::decimal, $33) AS "cbrt",
(all_types.numeric BETWEEN $11 AND all_types.decimal) AS "between",
(all_types.numeric NOT BETWEEN (all_types.decimal * $12) AND $13) AS "not_between",
TRUNC(all_types.decimal + all_types.decimal, $14::smallint) AS "add1",
TRUNC(all_types.decimal + $15, $16::smallint) AS "add2",
TRUNC(all_types.decimal - all_types.decimal_ptr, $17::integer) AS "sub1",
TRUNC(all_types.decimal - $18, $19::smallint) AS "sub2",
TRUNC(all_types.decimal * all_types.decimal_ptr, $20::smallint) AS "mul1",
TRUNC(all_types.decimal * $21, $22::integer) AS "mul2",
TRUNC(all_types.decimal / all_types.decimal_ptr, $23::integer) AS "div1",
TRUNC(all_types.decimal / $24, $25::smallint) AS "div2",
TRUNC(all_types.decimal % all_types.decimal_ptr, $26::smallint) AS "mod1",
TRUNC(all_types.decimal % $27, $28::smallint) AS "mod2",
TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $29::smallint) AS "pow1",
TRUNC(POW(all_types.decimal, $30), $31::smallint) AS "pow2",
TRUNC(ABS(all_types.decimal), $32::smallint) AS "abs",
TRUNC(POWER(all_types.decimal, $33), $34::smallint) AS "power",
TRUNC(SQRT(all_types.decimal), $35::smallint) AS "sqrt",
TRUNC(CBRT(all_types.decimal)::decimal, $36::smallint) AS "cbrt",
CEIL(all_types.real) AS "ceil",
FLOOR(all_types.real) AS "floor",
ROUND(all_types.decimal) AS "round1",
ROUND(all_types.decimal, all_types.integer) AS "round2",
SIGN(all_types.real) AS "sign",
TRUNC(all_types.decimal, $34) AS "trunc"
TRUNC(all_types.decimal, $37::integer) AS "trunc"
FROM test_sample.all_types
LIMIT $35;
LIMIT $38;
`)
var dest []struct {
@ -590,6 +598,7 @@ LIMIT $35;
require.NoError(t, err)
//testutils.PrintJson(dest)
// testutils.SaveJSONFile(dest, "./testdata/results/common/float_operators.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json")
}
@ -602,62 +611,50 @@ func TestIntegerOperators(t *testing.T) {
AllTypes.SmallIntPtr,
AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"),
AllTypes.BigInt.EQ(Int(12)).AS("eq2"),
AllTypes.BigInt.EQ(Int64(12)).AS("eq2"),
AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"),
AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"),
AllTypes.BigInt.NOT_EQ(Int64(12)).AS("neq2"),
AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"),
AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"),
AllTypes.BigInt.IS_DISTINCT_FROM(Int32(12)).AS("distinct2"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int32(12)).AS("not distinct2"),
AllTypes.Integer.BETWEEN(Int(11), Int(200)).AS("between"),
AllTypes.Integer.NOT_BETWEEN(Int(66), Int(77)).AS("not_between"),
AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"),
AllTypes.BigInt.LT(Int(65)).AS("lt2"),
AllTypes.BigInt.LT(Uint8(65)).AS("lt2"),
AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"),
AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"),
AllTypes.BigInt.LT_EQ(Uint16(65)).AS("lte2"),
AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"),
AllTypes.BigInt.GT(Int(65)).AS("gt2"),
AllTypes.BigInt.GT(Uint32(65)).AS("gt2"),
AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"),
AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"),
AllTypes.BigInt.GT_EQ(Uint64(65)).AS("gte2"),
AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"),
AllTypes.BigInt.ADD(Int(11)).AS("add2"),
AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"),
AllTypes.BigInt.SUB(Int(11)).AS("sub2"),
AllTypes.BigInt.SUB(Int8(11)).AS("sub2"),
AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"),
AllTypes.BigInt.MUL(Int(11)).AS("mul2"),
AllTypes.BigInt.MUL(Int16(11)).AS("mul2"),
AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"),
AllTypes.BigInt.DIV(Int(11)).AS("div2"),
AllTypes.BigInt.DIV(Int32(11)).AS("div2"),
AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"),
AllTypes.BigInt.MOD(Int(11)).AS("mod2"),
AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"),
AllTypes.SmallInt.POW(Int(6)).AS("pow2"),
AllTypes.BigInt.MOD(Int64(11)).AS("mod2"),
AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int8(3))).AS("pow1"),
AllTypes.SmallInt.POW(Int8(6)).AS("pow2"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"),
AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"),
AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"),
AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"),
AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"),
BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"),
BIT_NOT(Int(-11)).AS("bit_not_2"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"),
@ -666,7 +663,7 @@ func TestIntegerOperators(t *testing.T) {
CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"),
).LIMIT(2)
//fmt.Println(query.Sql())
// fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, `
SELECT all_types.big_int AS "all_types.big_int",
@ -674,50 +671,52 @@ SELECT all_types.big_int AS "all_types.big_int",
all_types.small_int AS "all_types.small_int",
all_types.small_int_ptr AS "all_types.small_int_ptr",
(all_types.big_int = all_types.big_int) AS "eq1",
(all_types.big_int = $1) AS "eq2",
(all_types.big_int = $1::bigint) AS "eq2",
(all_types.big_int != all_types.big_int_ptr) AS "neq1",
(all_types.big_int != $2) AS "neq2",
(all_types.big_int != $2::bigint) AS "neq2",
(all_types.big_int IS DISTINCT FROM all_types.big_int) AS "distinct1",
(all_types.big_int IS DISTINCT FROM $3) AS "distinct2",
(all_types.big_int IS DISTINCT FROM $3::integer) AS "distinct2",
(all_types.big_int IS NOT DISTINCT FROM all_types.big_int) AS "not distinct1",
(all_types.big_int IS NOT DISTINCT FROM $4) AS "not distinct2",
(all_types.big_int IS NOT DISTINCT FROM $4::integer) AS "not distinct2",
(all_types.integer BETWEEN $5 AND $6) AS "between",
(all_types.integer NOT BETWEEN $7 AND $8) AS "not_between",
(all_types.big_int < all_types.big_int_ptr) AS "lt1",
(all_types.big_int < $5) AS "lt2",
(all_types.big_int < $9::smallint) AS "lt2",
(all_types.big_int <= all_types.big_int_ptr) AS "lte1",
(all_types.big_int <= $6) AS "lte2",
(all_types.big_int <= $10::integer) AS "lte2",
(all_types.big_int > all_types.big_int_ptr) AS "gt1",
(all_types.big_int > $7) AS "gt2",
(all_types.big_int > $11::bigint) AS "gt2",
(all_types.big_int >= all_types.big_int_ptr) AS "gte1",
(all_types.big_int >= $8) AS "gte2",
(all_types.big_int >= $12::bigint) AS "gte2",
(all_types.big_int + all_types.big_int) AS "add1",
(all_types.big_int + $9) AS "add2",
(all_types.big_int + $13) AS "add2",
(all_types.big_int - all_types.big_int) AS "sub1",
(all_types.big_int - $10) AS "sub2",
(all_types.big_int - $14::smallint) AS "sub2",
(all_types.big_int * all_types.big_int) AS "mul1",
(all_types.big_int * $11) AS "mul2",
(all_types.big_int * $15::smallint) AS "mul2",
(all_types.big_int / all_types.big_int) AS "div1",
(all_types.big_int / $12) AS "div2",
(all_types.big_int / $16::integer) AS "div2",
(all_types.big_int % all_types.big_int) AS "mod1",
(all_types.big_int % $13) AS "mod2",
POW(all_types.small_int, (all_types.small_int / $14)) AS "pow1",
POW(all_types.small_int, $15) AS "pow2",
(all_types.big_int % $17::bigint) AS "mod2",
POW(all_types.small_int, all_types.small_int / $18::smallint) AS "pow1",
POW(all_types.small_int, $19::smallint) AS "pow2",
(all_types.small_int & all_types.small_int) AS "bit_and1",
(all_types.small_int & all_types.small_int) AS "bit_and2",
(all_types.small_int | all_types.small_int) AS "bit or 1",
(all_types.small_int | $16) AS "bit or 2",
(all_types.small_int | $20) AS "bit or 2",
(all_types.small_int # all_types.small_int) AS "bit xor 1",
(all_types.small_int # $17) AS "bit xor 2",
(~ ($18 * all_types.small_int)) AS "bit_not_1",
(all_types.small_int # $21) AS "bit xor 2",
(~ ($22 * all_types.small_int)) AS "bit_not_1",
(~ -11) AS "bit_not_2",
(all_types.small_int << (all_types.small_int / $19)) AS "bit shift left 1",
(all_types.small_int << $20) AS "bit shift left 2",
(all_types.small_int >> (all_types.small_int / $21)) AS "bit shift right 1",
(all_types.small_int >> $22) AS "bit shift right 2",
(all_types.small_int << (all_types.small_int / $23::smallint)) AS "bit shift left 1",
(all_types.small_int << $24) AS "bit shift left 2",
(all_types.small_int >> (all_types.small_int / $25)) AS "bit shift right 1",
(all_types.small_int >> $26) AS "bit shift right 2",
ABS(all_types.big_int) AS "abs",
SQRT(ABS(all_types.big_int)) AS "sqrt",
CBRT(ABS(all_types.big_int)) AS "cbrt"
FROM test_sample.all_types
LIMIT $23;
LIMIT $27;
`)
var dest []struct {
@ -728,7 +727,7 @@ LIMIT $23;
require.NoError(t, err)
//testutils.SaveJsonFile("./testdata/common/int_operators.json", dest)
//testutils.SaveJSONFile(dest, "./testdata/results/common/int_operators.json")
//testutils.PrintJson(dest)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/int_operators.json")
}
@ -759,21 +758,18 @@ func TestTimeExpression(t *testing.T) {
AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_DISTINCT_FROM(Time(23, 6, 6, 100)),
AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(23, 6, 6, 200)),
AllTypes.Time.LT(AllTypes.Time),
AllTypes.Time.LT(Time(23, 6, 6, 22)),
AllTypes.Time.LT_EQ(AllTypes.Time),
AllTypes.Time.LT_EQ(Time(23, 6, 6, 33)),
AllTypes.Time.GT(AllTypes.Time),
AllTypes.Time.GT(Time(23, 6, 6, 0)),
AllTypes.Time.GT_EQ(AllTypes.Time),
AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)),
AllTypes.Time.BETWEEN(Time(11, 0, 30, 100), TimeT(time.Now())),
AllTypes.Time.NOT_BETWEEN(AllTypes.TimePtr, AllTypes.Time.ADD(INTERVAL(2, HOUR))),
AllTypes.Date.ADD(INTERVAL(1, HOUR)),
AllTypes.Date.SUB(INTERVAL(1, MINUTE)),
@ -781,12 +777,20 @@ func TestTimeExpression(t *testing.T) {
AllTypes.Time.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timez.ADD(INTERVAL(1, HOUR)),
AllTypes.Timez.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timez.BETWEEN(TimezT(time.Now()), AllTypes.TimezPtr),
AllTypes.Timez.NOT_BETWEEN(AllTypes.Timez, TimezT(time.Now())),
AllTypes.Timestamp.ADD(INTERVAL(1, HOUR)),
AllTypes.Timestamp.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timestamp.BETWEEN(AllTypes.TimestampPtr, TimestampT(time.Now())),
AllTypes.Timestamp.NOT_BETWEEN(TimestampT(time.Now()), AllTypes.TimestampPtr),
AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)),
AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timestamp.BETWEEN(AllTypes.TimestampPtr, TimestampT(time.Now())),
AllTypes.Timestamp.NOT_BETWEEN(AllTypes.TimestampPtr, TimestampT(time.Now())),
AllTypes.Date.SUB(CAST(String("04:05:06")).AS_INTERVAL()),
AllTypes.Date.BETWEEN(Date(2000, 2, 2), DateT(time.Now())),
AllTypes.Date.NOT_BETWEEN(AllTypes.DatePtr, DateT(time.Now().Add(20*time.Hour))),
CURRENT_DATE(),
CURRENT_TIME(),
@ -847,6 +851,8 @@ func TestInterval(t *testing.T) {
AllTypes.Interval.LT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.GT(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.GT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.BETWEEN(INTERVAL(1, HOUR), INTERVAL(2, HOUR)),
AllTypes.Interval.NOT_BETWEEN(AllTypes.IntervalPtr, INTERVALd(30*time.Second)),
AllTypes.Interval.ADD(AllTypes.IntervalPtr).EQ(INTERVALd(17*time.Second)),
AllTypes.Interval.SUB(AllTypes.IntervalPtr).EQ(INTERVAL(100, MICROSECOND)),
AllTypes.IntervalPtr.MUL(Int(11)).EQ(AllTypes.Interval),

View file

@ -35,6 +35,7 @@ ORDER BY "Album"."AlbumId" ASC;
testutils.AssertDeepEqual(t, dest[1], album2)
testutils.AssertDeepEqual(t, dest[len(dest)-1], album347)
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 347)
}
func TestJoinEverything(t *testing.T) {
@ -101,12 +102,341 @@ func TestJoinEverything(t *testing.T) {
}
}
err := stmt.Query(db, &dest)
testutils.AssertStatementSql(t, stmt, `
SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name",
"Album"."AlbumId" AS "Album.AlbumId",
"Album"."Title" AS "Album.Title",
"Album"."ArtistId" AS "Album.ArtistId",
"Track"."TrackId" AS "Track.TrackId",
"Track"."Name" AS "Track.Name",
"Track"."AlbumId" AS "Track.AlbumId",
"Track"."MediaTypeId" AS "Track.MediaTypeId",
"Track"."GenreId" AS "Track.GenreId",
"Track"."Composer" AS "Track.Composer",
"Track"."Milliseconds" AS "Track.Milliseconds",
"Track"."Bytes" AS "Track.Bytes",
"Track"."UnitPrice" AS "Track.UnitPrice",
"Genre"."GenreId" AS "Genre.GenreId",
"Genre"."Name" AS "Genre.Name",
"MediaType"."MediaTypeId" AS "MediaType.MediaTypeId",
"MediaType"."Name" AS "MediaType.Name",
"PlaylistTrack"."PlaylistId" AS "PlaylistTrack.PlaylistId",
"PlaylistTrack"."TrackId" AS "PlaylistTrack.TrackId",
"Playlist"."PlaylistId" AS "Playlist.PlaylistId",
"Playlist"."Name" AS "Playlist.Name",
"Invoice"."InvoiceId" AS "Invoice.InvoiceId",
"Invoice"."CustomerId" AS "Invoice.CustomerId",
"Invoice"."InvoiceDate" AS "Invoice.InvoiceDate",
"Invoice"."BillingAddress" AS "Invoice.BillingAddress",
"Invoice"."BillingCity" AS "Invoice.BillingCity",
"Invoice"."BillingState" AS "Invoice.BillingState",
"Invoice"."BillingCountry" AS "Invoice.BillingCountry",
"Invoice"."BillingPostalCode" AS "Invoice.BillingPostalCode",
"Invoice"."Total" AS "Invoice.Total",
"Customer"."CustomerId" AS "Customer.CustomerId",
"Customer"."FirstName" AS "Customer.FirstName",
"Customer"."LastName" AS "Customer.LastName",
"Customer"."Company" AS "Customer.Company",
"Customer"."Address" AS "Customer.Address",
"Customer"."City" AS "Customer.City",
"Customer"."State" AS "Customer.State",
"Customer"."Country" AS "Customer.Country",
"Customer"."PostalCode" AS "Customer.PostalCode",
"Customer"."Phone" AS "Customer.Phone",
"Customer"."Fax" AS "Customer.Fax",
"Customer"."Email" AS "Customer.Email",
"Customer"."SupportRepId" AS "Customer.SupportRepId",
"Employee"."EmployeeId" AS "Employee.EmployeeId",
"Employee"."LastName" AS "Employee.LastName",
"Employee"."FirstName" AS "Employee.FirstName",
"Employee"."Title" AS "Employee.Title",
"Employee"."ReportsTo" AS "Employee.ReportsTo",
"Employee"."BirthDate" AS "Employee.BirthDate",
"Employee"."HireDate" AS "Employee.HireDate",
"Employee"."Address" AS "Employee.Address",
"Employee"."City" AS "Employee.City",
"Employee"."State" AS "Employee.State",
"Employee"."Country" AS "Employee.Country",
"Employee"."PostalCode" AS "Employee.PostalCode",
"Employee"."Phone" AS "Employee.Phone",
"Employee"."Fax" AS "Employee.Fax",
"Employee"."Email" AS "Employee.Email",
"Manager"."EmployeeId" AS "Manager.EmployeeId",
"Manager"."LastName" AS "Manager.LastName",
"Manager"."FirstName" AS "Manager.FirstName",
"Manager"."Title" AS "Manager.Title",
"Manager"."ReportsTo" AS "Manager.ReportsTo",
"Manager"."BirthDate" AS "Manager.BirthDate",
"Manager"."HireDate" AS "Manager.HireDate",
"Manager"."Address" AS "Manager.Address",
"Manager"."City" AS "Manager.City",
"Manager"."State" AS "Manager.State",
"Manager"."Country" AS "Manager.Country",
"Manager"."PostalCode" AS "Manager.PostalCode",
"Manager"."Phone" AS "Manager.Phone",
"Manager"."Fax" AS "Manager.Fax",
"Manager"."Email" AS "Manager.Email"
FROM chinook."Artist"
LEFT JOIN chinook."Album" ON ("Artist"."ArtistId" = "Album"."ArtistId")
LEFT JOIN chinook."Track" ON ("Track"."AlbumId" = "Album"."AlbumId")
LEFT JOIN chinook."Genre" ON ("Genre"."GenreId" = "Track"."GenreId")
LEFT JOIN chinook."MediaType" ON ("MediaType"."MediaTypeId" = "Track"."MediaTypeId")
LEFT JOIN chinook."PlaylistTrack" ON ("PlaylistTrack"."TrackId" = "Track"."TrackId")
LEFT JOIN chinook."Playlist" ON ("Playlist"."PlaylistId" = "PlaylistTrack"."PlaylistId")
LEFT JOIN chinook."InvoiceLine" ON ("InvoiceLine"."TrackId" = "Track"."TrackId")
LEFT JOIN chinook."Invoice" ON ("Invoice"."InvoiceId" = "InvoiceLine"."InvoiceId")
LEFT JOIN chinook."Customer" ON ("Customer"."CustomerId" = "Invoice"."CustomerId")
LEFT JOIN chinook."Employee" ON ("Employee"."EmployeeId" = "Customer"."SupportRepId")
LEFT JOIN chinook."Employee" AS "Manager" ON ("Manager"."EmployeeId" = "Employee"."ReportsTo")
ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId";
`)
err := stmt.QueryContext(context.Background(), db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 275)
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json")
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 9423)
}
// default column aliases from sub-CTEs are bubbled up to the main query,
// cte name does not affect default column alias in main query
func TestSubQueryColumnAliasBubbling(t *testing.T) {
subQuery1 := SELECT(
Artist.AllColumns,
String("custom_column_1").AS("custom_column_1"),
).FROM(
Artist,
).ORDER_BY(
Artist.ArtistId.ASC(),
).AsTable("subQuery1")
subQuery2 := SELECT(
subQuery1.AllColumns(),
String("custom_column_2").AS("custom_column_2"),
).FROM(
subQuery1,
).AsTable("subQuery2")
mainQuery := SELECT(
subQuery2.AllColumns(), // columns will have the same alias as in the sub-query
subQuery2.AllColumns().As("artist2.*"), // all column aliases will be changed to artist2.*
subQuery2.AllColumns().Except(Artist.Name).As("artist3.*"),
subQuery2.AllColumns().Except(
Artist.MutableColumns,
StringColumn("custom_column_1").From(subQuery2), // custom_column_1 appears with the same alias in subQuery2
StringColumn("custom_column_2").From(subQuery2),
).As("artist4.*"),
).FROM(
subQuery2,
)
// fmt.Println(mainQuery.Sql())
testutils.AssertStatementSql(t, mainQuery, `
SELECT "subQuery2"."Artist.ArtistId" AS "Artist.ArtistId",
"subQuery2"."Artist.Name" AS "Artist.Name",
"subQuery2".custom_column_1 AS "custom_column_1",
"subQuery2".custom_column_2 AS "custom_column_2",
"subQuery2"."Artist.ArtistId" AS "artist2.ArtistId",
"subQuery2"."Artist.Name" AS "artist2.Name",
"subQuery2".custom_column_1 AS "artist2.custom_column_1",
"subQuery2".custom_column_2 AS "artist2.custom_column_2",
"subQuery2"."Artist.ArtistId" AS "artist3.ArtistId",
"subQuery2".custom_column_1 AS "artist3.custom_column_1",
"subQuery2".custom_column_2 AS "artist3.custom_column_2",
"subQuery2"."Artist.ArtistId" AS "artist4.ArtistId"
FROM (
SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId",
"subQuery1"."Artist.Name" AS "Artist.Name",
"subQuery1".custom_column_1 AS "custom_column_1",
$1 AS "custom_column_2"
FROM (
SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name",
$2 AS "custom_column_1"
FROM chinook."Artist"
ORDER BY "Artist"."ArtistId" ASC
) AS "subQuery1"
) AS "subQuery2";
`)
var dest []struct {
// subQuery2.AllColumns()
Artist1 struct {
model.Artist
CustomColumn1 string
CustomColumn2 string
}
// subQuery2.AllColumns().As("artist2.*")
Artist2 struct {
model.Artist `alias:"artist2.*"`
CustomColumn1 string
CustomColumn2 string
} `alias:"artist2.*"`
// subQuery2.AllColumns().Except(Artist.Name).As("artist3.*")
Artist3 struct {
model.Artist `alias:"artist3.*"`
CustomColumn1 string
CustomColumn2 string
} `alias:"artist3.*"`
// subQuery2.AllColumns().Except(...).As("artist4.*")
Artist4 struct {
model.Artist `alias:"artist4.*"`
CustomColumn1 string
CustomColumn2 string
} `alias:"artist4.*"`
}
err := mainQuery.Query(db, &dest)
require.NoError(t, err)
// Artist1
require.Len(t, dest, 275)
require.Equal(t, dest[0].Artist1.Artist, model.Artist{
ArtistId: 1,
Name: testutils.StringPtr("AC/DC"),
})
require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1")
require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2")
// Artist2
require.Equal(t, testutils.ToJSON(dest[0].Artist1), testutils.ToJSON(dest[0].Artist2))
// Artist3
require.Equal(t, dest[0].Artist3.ArtistId, int32(1))
require.Nil(t, dest[0].Artist3.Name)
require.Equal(t, dest[0].Artist3.CustomColumn1, "custom_column_1")
require.Equal(t, dest[0].Artist3.CustomColumn2, "custom_column_2")
// Artist4
require.Equal(t, dest[0].Artist3.Artist, dest[0].Artist4.Artist)
require.Equal(t, dest[0].Artist4.CustomColumn1, "")
require.Equal(t, dest[0].Artist4.CustomColumn2, "")
}
func TestUnAliasedNamesPanicError(t *testing.T) {
subQuery1 := SELECT(
Artist.AllColumns,
Artist.Name.CONCAT(String("-musician")), //alias missing
).FROM(
Artist,
).ORDER_BY(
Artist.ArtistId.ASC(),
).AsTable("subQuery1")
require.Panics(t, func() {
SELECT(
subQuery1.AllColumns(), // panic, column not aliased
).FROM(
subQuery1,
)
}, "jet: can't export unaliased expression subQuery: subQuery1, expression: (\"Artist\".\"Name\" || '-musician')")
}
func TestProjectionListReAliasing(t *testing.T) {
projectionList := ProjectionList{
Track.GenreId,
SUM(Track.Milliseconds).AS("duration"),
MAX(Track.Milliseconds).AS("duration.max"),
}
stmt := SELECT(
projectionList.As("genre_info"),
).FROM(
Track,
).WHERE(
Track.GenreId.LT(Int(5)),
).GROUP_BY(
Track.GenreId,
).ORDER_BY(
Track.GenreId,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT "Track"."GenreId" AS "genre_info.GenreId",
SUM("Track"."Milliseconds") AS "genre_info.duration",
MAX("Track"."Milliseconds") AS "genre_info.max"
FROM chinook."Track"
WHERE "Track"."GenreId" < 5
GROUP BY "Track"."GenreId"
ORDER BY "Track"."GenreId";
`)
type GenreInfo struct {
GenreID string
Duration int64
Max int64
}
var dest []GenreInfo
err := stmt.Query(db, &dest)
require.NoError(t, err)
expectedSQL := `
[
{
"GenreID": "1",
"Duration": 368231326,
"Max": 1612329
},
{
"GenreID": "2",
"Duration": 37928199,
"Max": 907520
},
{
"GenreID": "3",
"Duration": 115846292,
"Max": 816509
},
{
"GenreID": "4",
"Duration": 77805478,
"Max": 558602
}
]
`
testutils.AssertJSON(t, dest, expectedSQL)
subQuery := stmt.AsTable("subQuery")
mainStmt := SELECT(
subQuery.AllColumns().As("genre_information.*"),
).FROM(
subQuery,
)
testutils.AssertDebugStatementSql(t, mainStmt, `
SELECT "subQuery"."genre_info.GenreId" AS "genre_information.GenreId",
"subQuery"."genre_info.duration" AS "genre_information.duration",
"subQuery"."genre_info.max" AS "genre_information.max"
FROM (
SELECT "Track"."GenreId" AS "genre_info.GenreId",
SUM("Track"."Milliseconds") AS "genre_info.duration",
MAX("Track"."Milliseconds") AS "genre_info.max"
FROM chinook."Track"
WHERE "Track"."GenreId" < 5
GROUP BY "Track"."GenreId"
ORDER BY "Track"."GenreId"
) AS "subQuery";
`)
type GenreInformation GenreInfo
var newDest []GenreInformation
err = mainStmt.Query(db, &newDest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, expectedSQL)
}
func TestSelfJoin(t *testing.T) {
@ -413,3 +743,53 @@ var album347 = model.Album{
Title: "Koyaanisqatsi (Soundtrack from the Motion Picture)",
ArtistId: 275,
}
func TestAggregateFunc(t *testing.T) {
stmt := SELECT(
PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"),
PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"),
PERCENTILE_DISC(RawFloat("(select array_agg(s) from generate_series(0, 1, 0.2) as s)")).
WITHIN_GROUP_ORDER_BY(Invoice.BillingAddress.DESC()).AS("percentile_disc_3"),
PERCENTILE_CONT(Float(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"),
PERCENTILE_CONT(Float(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_int"),
MODE().WITHIN_GROUP_ORDER_BY(Invoice.BillingPostalCode.DESC()).AS("mode_1"),
).FROM(
Invoice,
).GROUP_BY(
Invoice.Total,
)
testutils.AssertStatementSql(t, stmt, `
SELECT PERCENTILE_DISC ($1::double precision) WITHIN GROUP (ORDER BY "Invoice"."InvoiceId") AS "percentile_disc_1",
PERCENTILE_DISC ("Invoice"."Total" / $2) WITHIN GROUP (ORDER BY "Invoice"."InvoiceDate" ASC) AS "percentile_disc_2",
PERCENTILE_DISC ((select array_agg(s) from generate_series(0, 1, 0.2) as s)) WITHIN GROUP (ORDER BY "Invoice"."BillingAddress" DESC) AS "percentile_disc_3",
PERCENTILE_CONT ($3::double precision) WITHIN GROUP (ORDER BY "Invoice"."Total") AS "percentile_cont_1",
PERCENTILE_CONT ($4::double precision) WITHIN GROUP (ORDER BY INTERVAL '1 HOUR' DESC) AS "percentile_cont_int",
MODE () WITHIN GROUP (ORDER BY "Invoice"."BillingPostalCode" DESC) AS "mode_1"
FROM chinook."Invoice"
GROUP BY "Invoice"."Total";
`, 0.1, 100.0, 0.3, 0.2)
var dest struct {
PercentileDisc1 string
PercentileDisc2 string
PercentileDisc3 string
PercentileCont1 string
Mode1 string
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
{
"PercentileDisc1": "41",
"PercentileDisc2": "2009-01-19T00:00:00Z",
"PercentileDisc3": "{\"Via Degli Scipioni, 43\",\"Qe 7 Bloco G\",\"Berger Stra<72>e 10\",\"696 Osborne Street\",\"2211 W Berry Street\",\"1033 N Park Ave\"}",
"PercentileCont1": "0.99",
"Mode1": "X1A 1N6"
}
`)
}

View file

@ -4,6 +4,8 @@ import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table"
"github.com/stretchr/testify/require"
@ -23,7 +25,14 @@ WHERE link.name IN ('Gmail', 'Outlook');
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
AssertExec(t, deleteStmt, 2)
res, err := deleteStmt.ExecContext(context.Background(), db)
require.NoError(t, err)
rows, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, rows, int64(2))
requireQueryLogged(t, deleteStmt, int64(2))
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
@ -103,3 +112,72 @@ func TestDeleteExecContext(t *testing.T) {
require.Error(t, err, "context deadline exceeded")
requireLogged(t, deleteStmt)
}
func TestDeleteFrom(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
stmt := table.Rental.DELETE().
USING(
table.Staff.
INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)),
table.Actor,
).
WHERE(
table.Staff.StaffID.EQ(table.Rental.StaffID).
AND(table.Staff.StaffID.EQ(Int(2))).
AND(table.Rental.RentalID.LT(Int(10))),
).
RETURNING(
table.Rental.AllColumns,
table.Store.AllColumns,
)
testutils.AssertStatementSql(t, stmt, `
DELETE FROM dvds.rental
USING dvds.staff
INNER JOIN dvds.store ON (store.store_id = staff.staff_id),
dvds.actor
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $1)) AND (rental.rental_id < $2)
RETURNING rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id",
rental.customer_id AS "rental.customer_id",
rental.return_date AS "rental.return_date",
rental.staff_id AS "rental.staff_id",
rental.last_update AS "rental.last_update",
store.store_id AS "store.store_id",
store.manager_staff_id AS "store.manager_staff_id",
store.address_id AS "store.address_id",
store.last_update AS "store.last_update";
`)
var dest []struct {
Rental model2.Rental
Store model2.Store
}
err := stmt.Query(tx, &dest)
require.NoError(t, err)
require.Len(t, dest, 3)
testutils.AssertJSON(t, dest[0], `
{
"Rental": {
"RentalID": 4,
"RentalDate": "2005-05-24T23:04:41Z",
"InventoryID": 2452,
"CustomerID": 333,
"ReturnDate": "2005-06-03T01:43:41Z",
"StaffID": 2,
"LastUpdate": "2006-02-16T02:30:53Z"
},
"Store": {
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "2006-02-15T09:57:12Z"
}
}
`)
}

View file

@ -27,7 +27,7 @@ var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table",
var dbConnection = postgres.DBConnection{
Host: dbconfig.PgHost,
Port: 5432,
Port: dbconfig.PgPort,
User: dbconfig.PgUser,
Password: dbconfig.PgPassword,
DBName: dbconfig.PgDBName,

View file

@ -7,6 +7,7 @@ import (
"os/exec"
"path/filepath"
"reflect"
"strconv"
"testing"
"github.com/go-jet/jet/v2/generator/postgres"
@ -52,8 +53,13 @@ func TestCmdGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir2)
require.NoError(t, err)
cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432",
"-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2)
cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost",
"-port="+strconv.Itoa(dbconfig.PgPort),
"-user=jet",
"-password=jet",
"-schema=dvds",
"-path="+genTestDir2)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
@ -86,6 +92,59 @@ func TestCmdGenerator(t *testing.T) {
require.NoError(t, err)
}
func TestGeneratorIgnoreTables(t *testing.T) {
err := os.RemoveAll(genTestDir2)
require.NoError(t, err)
cmd := exec.Command("jet",
"-source=PostgreSQL",
"-host=localhost",
"-port="+strconv.Itoa(dbconfig.PgPort),
"-user=jet",
"-password=jet",
"-dbname=jetdb",
"-schema=dvds",
"-ignore-tables=actor,ADDRESS,country, Film , cITY,",
"-ignore-views=Actor_info, FILM_LIST ,staff_list",
"-ignore-enums=mpaa_rating",
"-path="+genTestDir2)
fmt.Println(cmd.Args)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
err = cmd.Run()
require.NoError(t, err)
// Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "category.go",
"customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go")
// Enums SQL Builder files
_, err = ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum")
require.Error(t, err, "open ./.gentestdata2/jetdb/dvds/enum: no such file or directory")
modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "category.go",
"customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go",
"nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go")
}
func TestGenerator(t *testing.T) {
for i := 0; i < 3; i++ {

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"math/rand"
"os"
"runtime"
"testing"
"time"
@ -59,11 +60,21 @@ var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
var queryInfo postgres.QueryInfo
var callerFile string
var callerLine int
var callerFunction string
func init() {
postgres.SetLogger(func(ctx context.Context, statement postgres.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql()
})
postgres.SetQueryLogger(func(ctx context.Context, info postgres.QueryInfo) {
queryInfo = info
callerFile, callerLine, callerFunction = info.Caller()
})
}
func requireLogged(t *testing.T, statement postgres.Statement) {
@ -73,6 +84,21 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
require.Equal(t, loggedDebugSQL, statement.DebugSql())
}
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) {
query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql()
require.Equal(t, query, queryLogged)
require.Equal(t, args, argsLogged)
require.Equal(t, queryInfo.RowsProcessed, rowsProcessed)
pc, file, _, _ := runtime.Caller(1)
funcDetails := runtime.FuncForPC(pc)
require.Equal(t, file, callerFile)
require.NotEmpty(t, callerLine)
require.Equal(t, funcDetails.Name(), callerFunction)
}
func skipForPgxDriver(t *testing.T) {
if isPgxDriver() {
t.SkipNow()
@ -87,3 +113,9 @@ func isPgxDriver() bool {
return false
}
func beginTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -9,7 +9,7 @@ import (
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/v2/postgres"
)

View file

@ -62,7 +62,8 @@ func TestScanToValidDestination(t *testing.T) {
t.Run("global query function scan", func(t *testing.T) {
queryStr, args := oneInventoryQuery.Sql()
dest := []struct{}{}
err := qrm.Query(nil, db, queryStr, args, &dest)
rowProcessed, err := qrm.Query(nil, db, queryStr, args, &dest)
require.Equal(t, rowProcessed, int64(1))
require.NoError(t, err)
})
@ -782,6 +783,7 @@ func TestRowsScan(t *testing.T) {
require.NoError(t, err)
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 0)
}
func TestScanNumericToFloat(t *testing.T) {

View file

@ -48,6 +48,63 @@ WHERE actor.actor_id = 2;
requireLogged(t, query)
}
func TestSelectDistinctOn(t *testing.T) {
stmt := SELECT(
Rental.StaffID,
Rental.CustomerID,
Rental.RentalID,
).DISTINCT(
Rental.StaffID,
Rental.CustomerID,
).FROM(
Rental,
).WHERE(
Rental.CustomerID.LT(Int(2)),
).ORDER_BY(
Rental.StaffID.ASC(),
Rental.CustomerID.ASC(),
Rental.RentalID.ASC(),
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT DISTINCT ON (rental.staff_id, rental.customer_id) rental.staff_id AS "rental.staff_id",
rental.customer_id AS "rental.customer_id",
rental.rental_id AS "rental.rental_id"
FROM dvds.rental
WHERE rental.customer_id < 2
ORDER BY rental.staff_id ASC, rental.customer_id ASC, rental.rental_id ASC;
`)
var dest []model.Rental
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"RentalID": 573,
"RentalDate": "0001-01-01T00:00:00Z",
"InventoryID": 0,
"CustomerID": 1,
"ReturnDate": null,
"StaffID": 1,
"LastUpdate": "0001-01-01T00:00:00Z"
},
{
"RentalID": 76,
"RentalDate": "0001-01-01T00:00:00Z",
"InventoryID": 0,
"CustomerID": 1,
"ReturnDate": null,
"StaffID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
}
]
`)
}
func TestClassicSelect(t *testing.T) {
expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id",
@ -814,10 +871,10 @@ ORDER BY f1.film_id ASC;
type F1 model.Film
type F2 model.Film
theSameLengthFilms := []struct {
var theSameLengthFilms []struct {
F1 F1
F2 F2
}{}
}
err := query.Query(db, &theSameLengthFilms)
@ -858,68 +915,124 @@ LIMIT 1000;
Title2 string
Length int16
}
films := []thesameLengthFilms{}
var films []thesameLengthFilms
err := query.Query(db, &films)
require.NoError(t, err)
//spew.Dump(films)
require.Equal(t, len(films), 1000)
testutils.AssertDeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46})
}
func TestSubQuery(t *testing.T) {
expectedQuery := `
SELECT actor.actor_id AS "actor.actor_id",
rRatingFilms :=
SELECT(
Film.FilmID,
Film.Title,
Film.Rating,
).FROM(
Film,
).WHERE(
Film.Rating.EQ(enum.MpaaRating.R),
).AsTable("rFilms")
rFilmID := Film.FilmID.From(rRatingFilms)
stmt :=
SELECT(
rRatingFilms.AllColumns(),
Actor.AllColumns,
FilmActor.AllColumns,
).FROM(
rRatingFilms.
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)).
INNER_JOIN(Actor, FilmActor.ActorID.EQ(Actor.ActorID)),
).WHERE(
rFilmID.LT(Int(50)),
).ORDER_BY(
rFilmID.ASC(),
Actor.ActorID.ASC(),
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT "rFilms"."film.film_id" AS "film.film_id",
"rFilms"."film.title" AS "film.title",
"rFilms"."film.rating" AS "film.rating",
actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update",
film_actor.actor_id AS "film_actor.actor_id",
film_actor.film_id AS "film_actor.film_id",
film_actor.last_update AS "film_actor.last_update",
"rFilms"."film.film_id" AS "film.film_id",
"rFilms"."film.title" AS "film.title",
"rFilms"."film.rating" AS "film.rating"
FROM dvds.actor
INNER JOIN dvds.film_actor ON (actor.actor_id = film_actor.film_id)
INNER JOIN (
film_actor.last_update AS "film_actor.last_update"
FROM (
SELECT film.film_id AS "film.film_id",
film.title AS "film.title",
film.rating AS "film.rating"
FROM dvds.film
WHERE film.rating = 'R'
) AS "rFilms" ON (film_actor.film_id = "rFilms"."film.film_id");
`
) AS "rFilms"
INNER JOIN dvds.film_actor ON (film_actor.film_id = "rFilms"."film.film_id")
INNER JOIN dvds.actor ON (film_actor.actor_id = actor.actor_id)
WHERE "rFilms"."film.film_id" < 50
ORDER BY "rFilms"."film.film_id" ASC, actor.actor_id ASC;
`)
rRatingFilms := Film.
SELECT(
Film.FilmID,
Film.Title,
Film.Rating,
).
WHERE(Film.Rating.EQ(enum.MpaaRating.R)).
AsTable("rFilms")
var dest []struct {
model.Film
rFilmID := Film.FilmID.From(rRatingFilms)
query := Actor.
INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.FilmID)).
INNER_JOIN(rRatingFilms, FilmActor.FilmID.EQ(rFilmID)).
SELECT(
Actor.AllColumns,
FilmActor.AllColumns,
rRatingFilms.AllColumns(),
)
testutils.AssertDebugStatementSql(t, query, expectedQuery)
dest := []model.Actor{}
err := query.Query(db, &dest)
Actors []model.Actor
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 10)
testutils.AssertJSON(t, dest[0], `
{
"FilmID": 8,
"Title": "Airport Pollock",
"Description": null,
"ReleaseYear": null,
"LanguageID": 0,
"RentalDuration": 0,
"RentalRate": 0,
"Length": null,
"ReplacementCost": 0,
"Rating": "R",
"LastUpdate": "0001-01-01T00:00:00Z",
"SpecialFeatures": null,
"Fulltext": "",
"Actors": [
{
"ActorID": 55,
"FirstName": "Fay",
"LastName": "Kilmer",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
{
"ActorID": 96,
"FirstName": "Gene",
"LastName": "Willis",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
{
"ActorID": 110,
"FirstName": "Susan",
"LastName": "Davis",
"LastUpdate": "2013-05-26T14:47:57.62Z"
},
{
"ActorID": 138,
"FirstName": "Lucille",
"LastName": "Dee",
"LastUpdate": "2013-05-26T14:47:57.62Z"
}
]
}
`)
}
func TestSelectFunctions(t *testing.T) {
@ -1078,6 +1191,66 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC;
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json")
}
func TestAggregateFunctionDistinct(t *testing.T) {
stmt := SELECT(
Payment.CustomerID,
COUNT(DISTINCT(Payment.Amount)).AS("distinct.count"),
SUM(DISTINCT(Payment.Amount)).AS("distinct.sum"),
AVG(DISTINCT(Payment.Amount)).AS("distinct.avg"),
MIN(DISTINCT(Payment.PaymentDate)).AS("distinct.first_payment_date"),
MAX(DISTINCT(Payment.PaymentDate)).AS("distinct.last_payment_date"),
).FROM(
Payment,
).WHERE(
Payment.CustomerID.EQ(Int(1)),
).GROUP_BY(
Payment.CustomerID,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT payment.customer_id AS "payment.customer_id",
COUNT(DISTINCT payment.amount) AS "distinct.count",
SUM(DISTINCT payment.amount) AS "distinct.sum",
AVG(DISTINCT payment.amount) AS "distinct.avg",
MIN(DISTINCT payment.payment_date) AS "distinct.first_payment_date",
MAX(DISTINCT payment.payment_date) AS "distinct.last_payment_date"
FROM dvds.payment
WHERE payment.customer_id = 1
GROUP BY payment.customer_id;
`)
type Distinct struct {
model.Payment
Count int64
Sum float64
Avg float64
FirstPaymentDate time.Time
LastPaymentDate time.Time
}
var dest Distinct
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
{
"PaymentID": 0,
"CustomerID": 1,
"StaffID": 0,
"RentalID": 0,
"Amount": 0,
"PaymentDate": "0001-01-01T00:00:00Z",
"Count": 8,
"Sum": 38.92,
"Avg": 4.865,
"FirstPaymentDate": "2007-02-14T23:22:38.996577Z",
"LastPaymentDate": "2007-04-30T01:10:44.996577Z"
}
`)
}
func TestSelectGroupBy2(t *testing.T) {
expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id",
@ -1887,7 +2060,7 @@ SELECT customer.customer_id AS "customer.customer_id",
customer.last_update AS "customer.last_update",
customer.active AS "customer.active"
FROM dvds.customer
WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3);
WHERE ($1::boolean AND (customer.customer_id = $2)) AND (customer.activebool = $3::boolean);
`, true, int64(1), true)
dest := []model.Customer{}
@ -2056,3 +2229,353 @@ FROM dvds.address;
require.Len(t, dest, 603)
})
}
type FilmWrap struct {
model.Film
Actors []ActorWrap
}
type ActorWrap struct {
model.Actor
Films []FilmWrap
}
func TestRecursionScanNxM(t *testing.T) {
stmt := SELECT(
Actor.AllColumns,
Film.AllColumns,
).FROM(
Actor.
INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)).
INNER_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)),
).ORDER_BY(
Actor.ActorID,
Film.FilmID,
).LIMIT(100)
t.Run("film->actors", func(t *testing.T) {
var films []FilmWrap
err := stmt.Query(db, &films)
require.NoError(t, err)
require.Len(t, films, 95)
testutils.AssertJSON(t, films[:2], `
[
{
"FilmID": 1,
"Title": "Academy Dinosaur",
"Description": "A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies",
"ReleaseYear": 2006,
"LanguageID": 1,
"RentalDuration": 6,
"RentalRate": 0.99,
"Length": 86,
"ReplacementCost": 20.99,
"Rating": "PG",
"LastUpdate": "2013-05-26T14:50:58.951Z",
"SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}",
"Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17",
"Actors": [
{
"ActorID": 1,
"FirstName": "Penelope",
"LastName": "Guiness",
"LastUpdate": "2013-05-26T14:47:57.62Z",
"Films": null
}
]
},
{
"FilmID": 23,
"Title": "Anaconda Confessions",
"Description": "A Lacklusture Display of a Dentist And a Dentist who must Fight a Girl in Australia",
"ReleaseYear": 2006,
"LanguageID": 1,
"RentalDuration": 3,
"RentalRate": 0.99,
"Length": 92,
"ReplacementCost": 9.99,
"Rating": "R",
"LastUpdate": "2013-05-26T14:50:58.951Z",
"SpecialFeatures": "{Trailers,\"Deleted Scenes\"}",
"Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13",
"Actors": [
{
"ActorID": 1,
"FirstName": "Penelope",
"LastName": "Guiness",
"LastUpdate": "2013-05-26T14:47:57.62Z",
"Films": null
},
{
"ActorID": 4,
"FirstName": "Jennifer",
"LastName": "Davis",
"LastUpdate": "2013-05-26T14:47:57.62Z",
"Films": null
}
]
}
]
`)
})
t.Run("actors->films", func(t *testing.T) {
var actors []ActorWrap
err := stmt.Query(db, &actors)
require.NoError(t, err)
require.Equal(t, len(actors), 5)
require.Equal(t, actors[0].ActorID, int32(1))
require.Equal(t, actors[0].FirstName, "Penelope")
require.Len(t, actors[0].Films, 19)
testutils.AssertJSON(t, actors[0].Films[:2], `
[
{
"FilmID": 1,
"Title": "Academy Dinosaur",
"Description": "A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies",
"ReleaseYear": 2006,
"LanguageID": 1,
"RentalDuration": 6,
"RentalRate": 0.99,
"Length": 86,
"ReplacementCost": 20.99,
"Rating": "PG",
"LastUpdate": "2013-05-26T14:50:58.951Z",
"SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}",
"Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17",
"Actors": null
},
{
"FilmID": 23,
"Title": "Anaconda Confessions",
"Description": "A Lacklusture Display of a Dentist And a Dentist who must Fight a Girl in Australia",
"ReleaseYear": 2006,
"LanguageID": 1,
"RentalDuration": 3,
"RentalRate": 0.99,
"Length": 92,
"ReplacementCost": 9.99,
"Rating": "R",
"LastUpdate": "2013-05-26T14:50:58.951Z",
"SpecialFeatures": "{Trailers,\"Deleted Scenes\"}",
"Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13",
"Actors": null
}
]
`)
})
}
type StoreWrap struct {
model.Store
Staffs []StaffWrap
}
type StaffWrap struct {
model.Staff
Store StoreWrap
}
func TestRecursionScanNx1(t *testing.T) {
stmt := SELECT(
Store.AllColumns,
Staff.AllColumns,
).FROM(
Store.
INNER_JOIN(Staff, Staff.StoreID.EQ(Store.StoreID)),
).ORDER_BY(
Store.StoreID,
Staff.StaffID,
)
t.Run("store->staff", func(t *testing.T) {
var stores []StoreWrap
err := stmt.Query(db, &stores)
require.NoError(t, err)
require.Len(t, stores, 2)
testutils.AssertJSON(t, stores, `
[
{
"StoreID": 1,
"ManagerStaffID": 1,
"AddressID": 1,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": [
{
"StaffID": 1,
"FirstName": "Mike",
"LastName": "Hillyer",
"AddressID": 3,
"Email": "Mike.Hillyer@sakilastaff.com",
"StoreID": 1,
"Active": true,
"Username": "Mike",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": "iVBORw0KWgo=",
"Store": {
"StoreID": 0,
"ManagerStaffID": 0,
"AddressID": 0,
"LastUpdate": "0001-01-01T00:00:00Z",
"Staffs": null
}
}
]
},
{
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": [
{
"StaffID": 2,
"FirstName": "Jon",
"LastName": "Stephens",
"AddressID": 4,
"Email": "Jon.Stephens@sakilastaff.com",
"StoreID": 2,
"Active": true,
"Username": "Jon",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": null,
"Store": {
"StoreID": 0,
"ManagerStaffID": 0,
"AddressID": 0,
"LastUpdate": "0001-01-01T00:00:00Z",
"Staffs": null
}
}
]
}
]
`)
})
t.Run("staff->store", func(t *testing.T) {
var staffs []StaffWrap
err := stmt.Query(db, &staffs)
require.NoError(t, err)
testutils.AssertJSON(t, staffs, `
[
{
"StaffID": 1,
"FirstName": "Mike",
"LastName": "Hillyer",
"AddressID": 3,
"Email": "Mike.Hillyer@sakilastaff.com",
"StoreID": 1,
"Active": true,
"Username": "Mike",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": "iVBORw0KWgo=",
"Store": {
"StoreID": 1,
"ManagerStaffID": 1,
"AddressID": 1,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": null
}
},
{
"StaffID": 2,
"FirstName": "Jon",
"LastName": "Stephens",
"AddressID": 4,
"Email": "Jon.Stephens@sakilastaff.com",
"StoreID": 2,
"Active": true,
"Username": "Jon",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": null,
"Store": {
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": null
}
}
]
`)
})
}
// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions,
// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error.
// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible.
func TestLiteralTypeDeduction(t *testing.T) {
stmt := SELECT(
SUM(
CASE().WHEN(Staff.Active.IS_TRUE()).
THEN(Int8(6)). // if Int8 and Int32 are replaced with Int,
ELSE(Int32(-1)), // execution of this statement will return an error
).AS("num_passed"),
).FROM(Staff)
testutils.AssertStatementSql(t, stmt, `
SELECT SUM((CASE WHEN staff.active IS TRUE THEN $1::smallint ELSE $2::integer END)) AS "num_passed"
FROM dvds.staff;
`)
err := stmt.Query(db, &struct{}{})
require.NoError(t, err)
}
func GET_FILM_COUNT(lenFrom, lenTo IntegerExpression) IntegerExpression {
return IntExp(Func("dvds.get_film_count", lenFrom, lenTo))
}
func TestCustomFunctionCall(t *testing.T) {
stmt := SELECT(
GET_FILM_COUNT(Int(100), Int(120)).AS("film_count"),
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT dvds.get_film_count(100, 120) AS "film_count";
`)
var dest struct {
FilmCount int
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.FilmCount, 165)
stmt2 := SELECT(
Raw("dvds.get_film_count(#1, #2)", RawArgs{"#1": 100, "#2": 120}).AS("film_count"),
)
err = stmt2.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.FilmCount, 165)
stmt3 := RawStatement(`
SELECT dvds.get_film_count(#1, #2) AS "film_count";`, RawArgs{"#1": 100, "#2": 120},
)
err = stmt3.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.FilmCount, 165)
}

View file

@ -4,6 +4,8 @@ import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table"
"github.com/stretchr/testify/require"
@ -264,11 +266,13 @@ func TestUpdateWithModelData(t *testing.T) {
expectedSQL := `
UPDATE test_sample.link
SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL)
WHERE link.id = 201;
WHERE link.id = 201::integer;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
AssertExec(t, stmt, 1)
_, err := stmt.Exec(db)
require.NoError(t, err)
requireQueryLogged(t, stmt, 1)
}
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
@ -291,7 +295,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
var expectedSQL = `
UPDATE test_sample.link
SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.id = 201;
WHERE link.id = 201::integer;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201))
@ -371,6 +375,77 @@ func TestUpdateExecContext(t *testing.T) {
require.Error(t, err, "context deadline exceeded")
}
func TestUpdateFrom(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
stmt := table.Rental.UPDATE().
SET(
table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)),
).
FROM(
table.Staff.
INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)),
table.Actor,
).
WHERE(
table.Staff.StaffID.EQ(table.Rental.StaffID).
AND(table.Staff.StaffID.EQ(Int(2))).
AND(table.Rental.RentalID.LT(Int(10))),
).
RETURNING(
table.Rental.AllColumns.Except(table.Rental.LastUpdate),
table.Store.AllColumns.Except(table.Store.LastUpdate),
)
testutils.AssertStatementSql(t, stmt, `
UPDATE dvds.rental
SET rental_date = $1::timestamp without time zone
FROM dvds.staff
INNER JOIN dvds.store ON (store.store_id = staff.staff_id),
dvds.actor
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $2)) AND (rental.rental_id < $3)
RETURNING rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id",
rental.customer_id AS "rental.customer_id",
rental.return_date AS "rental.return_date",
rental.staff_id AS "rental.staff_id",
store.store_id AS "store.store_id",
store.manager_staff_id AS "store.manager_staff_id",
store.address_id AS "store.address_id";
`)
var dest []struct {
Rental model2.Rental
Store model2.Store
}
err := stmt.Query(tx, &dest)
require.NoError(t, err)
require.Len(t, dest, 3)
testutils.AssertJSON(t, dest[0], `
{
"Rental": {
"RentalID": 4,
"RentalDate": "2020-02-02T00:00:00Z",
"InventoryID": 2452,
"CustomerID": 333,
"ReturnDate": "2005-06-03T01:43:41Z",
"StaffID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
},
"Store": {
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
}
}
`)
}
func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t)

View file

@ -1,6 +1,8 @@
package postgres
import (
"context"
"fmt"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model"
@ -143,7 +145,7 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10)
require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10)
//fmt.Println(stmt.Sql())
// fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, `
WITH remove_discontinued_orders AS (
@ -217,5 +219,650 @@ FROM log_discontinued;
err = stmt.Query(tx, &resp)
require.NoError(t, err)
}
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
// CTE columns are listed as part of CTE definition
n1 := IntegerColumn("n1")
fibN1 := IntegerColumn("fibN1")
nextFibN1 := IntegerColumn("nextFibN1")
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
// CTE columns are columns from non-recursive select
fibonacci2 := CTE("fibonacci2")
n2 := IntegerColumn("n2").From(fibonacci2)
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
stmt := WITH_RECURSIVE(
fibonacci1.AS(
SELECT(
Int32(1), Int32(0), Int32(1),
).UNION_ALL(
SELECT(
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
).FROM(
fibonacci1,
).WHERE(
n1.LT(Int(20)),
),
),
),
fibonacci2.AS(
SELECT(
Int32(1).AS(n2.Name()),
Int32(0).AS(fibN2.Name()),
Int32(1).AS(nextFibN2.Name()),
).UNION_ALL(
SELECT(
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
).FROM(
fibonacci2,
).WHERE(
n2.LT(Int(20)),
),
),
),
)(
SELECT(
fibonacci1.AllColumns(),
fibonacci2.AllColumns(),
).FROM(
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
).WHERE(
n1.EQ(Int(20)),
),
)
//fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, `
WITH RECURSIVE fibonacci1 (n1, "fibN1", "nextFibN1") AS (
(
SELECT $1::integer,
$2::integer,
$3::integer
)
UNION ALL
(
SELECT fibonacci1.n1 + $4,
fibonacci1."nextFibN1" AS "nextFibN1",
fibonacci1."fibN1" + fibonacci1."nextFibN1"
FROM fibonacci1
WHERE fibonacci1.n1 < $5
)
),fibonacci2 AS (
(
SELECT $6::integer AS "n2",
$7::integer AS "fibN2",
$8::integer AS "nextFibN2"
)
UNION ALL
(
SELECT fibonacci2.n2 + $9,
fibonacci2."nextFibN2" AS "nextFibN2",
fibonacci2."fibN2" + fibonacci2."nextFibN2"
FROM fibonacci2
WHERE fibonacci2.n2 < $10
)
)
SELECT fibonacci1.n1 AS "n1",
fibonacci1."fibN1" AS "fibN1",
fibonacci1."nextFibN1" AS "nextFibN1",
fibonacci2.n2 AS "n2",
fibonacci2."fibN2" AS "fibN2",
fibonacci2."nextFibN2" AS "nextFibN2"
FROM fibonacci1
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
WHERE fibonacci1.n1 = $11;
`)
var dest struct {
N1 int
FibN1 int
NextFibN1 int
N2 int
FibN2 int
NextFibN2 int
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.N1, 20)
require.Equal(t, dest.FibN1, 4181)
require.Equal(t, dest.NextFibN1, 6765)
require.Equal(t, dest.N2, 20)
require.Equal(t, dest.FibN2, 4181)
require.Equal(t, dest.NextFibN2, 6765)
}
// default column aliases from sub-queries are bubbled up to the main query,
// cte name does not affect default column alias in main query
func TestCTEColumnAliasBubbling(t *testing.T) {
cte1 := CTE("cte1")
cte2 := CTE("cte2")
stmt := WITH(
cte1.AS(
SELECT(
Territories.AllColumns,
String("custom_column_1").AS("custom_column_1"),
).FROM(
Territories,
).ORDER_BY(
Territories.TerritoryID.ASC(),
),
),
cte2.AS(
SELECT(
cte1.AllColumns(),
String("custom_column_2").AS("custom_column_2"),
).FROM(
cte1,
),
),
)(
SELECT(
cte2.AllColumns(), // columns will have the same alias as in CTEs
cte2.AllColumns().As("territories2.*"), // all column aliases will be changed to territories2.*
cte2.AllColumns().Except(Territories.RegionID, Territories.TerritoryDescription).As("territories3.*"),
cte2.AllColumns().
Except(
Territories.MutableColumns,
StringColumn("custom_column_1").From(cte2), // custom_column_1 appears with the same alias in cte2
StringColumn("custom_column_2").From(cte2),
).As("territories4.*"),
).FROM(
cte2,
),
)
// fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, `
WITH cte1 AS (
SELECT territories.territory_id AS "territories.territory_id",
territories.territory_description AS "territories.territory_description",
territories.region_id AS "territories.region_id",
$1 AS "custom_column_1"
FROM northwind.territories
ORDER BY territories.territory_id ASC
),cte2 AS (
SELECT cte1."territories.territory_id" AS "territories.territory_id",
cte1."territories.territory_description" AS "territories.territory_description",
cte1."territories.region_id" AS "territories.region_id",
cte1.custom_column_1 AS "custom_column_1",
$2 AS "custom_column_2"
FROM cte1
)
SELECT cte2."territories.territory_id" AS "territories.territory_id",
cte2."territories.territory_description" AS "territories.territory_description",
cte2."territories.region_id" AS "territories.region_id",
cte2.custom_column_1 AS "custom_column_1",
cte2.custom_column_2 AS "custom_column_2",
cte2."territories.territory_id" AS "territories2.territory_id",
cte2."territories.territory_description" AS "territories2.territory_description",
cte2."territories.region_id" AS "territories2.region_id",
cte2.custom_column_1 AS "territories2.custom_column_1",
cte2.custom_column_2 AS "territories2.custom_column_2",
cte2."territories.territory_id" AS "territories3.territory_id",
cte2.custom_column_1 AS "territories3.custom_column_1",
cte2.custom_column_2 AS "territories3.custom_column_2",
cte2."territories.territory_id" AS "territories4.territory_id"
FROM cte2;
`)
var dest []struct {
// cte2.AllColumns()
Territories1 struct {
model.Territories
CustomColumn1 string
CustomColumn2 string
}
// cte2.AllColumns().As("territories2.*")
Territories2 struct {
model.Territories `alias:"territories2.*"`
CustomColumn1 string
CustomColumn2 string
} `alias:"territories2.*"`
// cte2.AllColumns().Except(Territories.RegionID, Territories.TerritoryDescription).As("territories3.*")
Territories3 struct {
model.Territories `alias:"territories3.*"`
CustomColumn1 string
CustomColumn2 string
} `alias:"territories3.*"`
// cte2.AllColumns() ... .As("territories4.*")
Territories4 struct {
model.Territories `alias:"territories3.*"`
CustomColumn1 string
CustomColumn2 string
} `alias:"territories4.*"`
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 53)
require.Equal(t, dest[0].Territories1.Territories, model.Territories{
TerritoryID: "01581",
TerritoryDescription: "Westboro",
RegionID: 1,
})
require.Equal(t, dest[0].Territories1.CustomColumn1, "custom_column_1")
require.Equal(t, dest[0].Territories1.CustomColumn2, "custom_column_2")
// Territories2
require.Equal(t, testutils.ToJSON(dest[0].Territories1), testutils.ToJSON(dest[0].Territories2))
// Territories3
require.Equal(t, dest[0].Territories3.TerritoryID, dest[0].Territories1.TerritoryID)
require.Equal(t, dest[0].Territories3.RegionID, int16(0))
require.Equal(t, dest[0].Territories3.TerritoryDescription, "")
require.Equal(t, dest[0].Territories1.CustomColumn1, dest[0].Territories3.CustomColumn1)
require.Equal(t, dest[0].Territories1.CustomColumn2, dest[0].Territories3.CustomColumn2)
// Territories4
require.Equal(t, dest[0].Territories3.Territories, dest[0].Territories4.Territories)
require.Equal(t, dest[0].Territories4.CustomColumn1, "")
require.Equal(t, dest[0].Territories4.CustomColumn2, "")
}
func TestRecursiveWithStatement(t *testing.T) {
subordinates := CTE("subordinates")
stmt := WITH_RECURSIVE(
subordinates.AS(
SELECT(
Employees.AllColumns,
).FROM(
Employees,
).WHERE(
Employees.EmployeeID.EQ(Int(2)),
).UNION(
SELECT(
Employees.AllColumns,
).FROM(
Employees.
INNER_JOIN(subordinates, Employees.EmployeeID.From(subordinates).EQ(Employees.ReportsTo)),
),
),
),
)(
SELECT(
subordinates.AllColumns(),
).FROM(
subordinates,
),
)
//fmt.Println(stmt.DebugSql())
type EmployeeWrap struct {
model.Employees
Subordinates []*EmployeeWrap
}
type employeeID = int16
employeeMap := make(map[employeeID]*EmployeeWrap)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
var result *EmployeeWrap
for rows.Next() {
var employeeModel model.Employees
err := rows.Scan(&employeeModel)
require.NoError(t, err)
newEmployeeWrap := &EmployeeWrap{
Employees: employeeModel,
}
employeeMap[employeeModel.EmployeeID] = newEmployeeWrap
if result == nil { // top manager(always first row in the result)
result = newEmployeeWrap
continue
}
if employee, ok := employeeMap[*employeeModel.ReportsTo]; ok {
employee.Subordinates = append(employee.Subordinates, newEmployeeWrap)
}
}
require.NoError(t, rows.Err())
require.NoError(t, rows.Close())
testutils.AssertJSON(t, *result, `
{
"EmployeeID": 2,
"LastName": "Fuller",
"FirstName": "Andrew",
"Title": "Vice President, Sales",
"TitleOfCourtesy": "Dr.",
"BirthDate": "1952-02-19T00:00:00Z",
"HireDate": "1992-08-14T00:00:00Z",
"Address": "908 W. Capital Way",
"City": "Tacoma",
"Region": "WA",
"PostalCode": "98401",
"Country": "USA",
"HomePhone": "(206) 555-9482",
"Extension": "3457",
"Photo": "",
"Notes": "Andrew received his BTS commercial in 1974 and a Ph.D. in international marketing from the University of Dallas in 1981. He is fluent in French and Italian and reads German. He joined the company as a sales representative, was promoted to sales manager in January 1992 and to vice president of sales in March 1993. Andrew is a member of the Sales Management Roundtable, the Seattle Chamber of Commerce, and the Pacific Rim Importers Association.",
"ReportsTo": null,
"PhotoPath": "http://accweb/emmployees/fuller.bmp",
"Subordinates": [
{
"EmployeeID": 1,
"LastName": "Davolio",
"FirstName": "Nancy",
"Title": "Sales Representative",
"TitleOfCourtesy": "Ms.",
"BirthDate": "1948-12-08T00:00:00Z",
"HireDate": "1992-05-01T00:00:00Z",
"Address": "507 - 20th Ave. E.\\nApt. 2A",
"City": "Seattle",
"Region": "WA",
"PostalCode": "98122",
"Country": "USA",
"HomePhone": "(206) 555-9857",
"Extension": "5467",
"Photo": "",
"Notes": "Education includes a BA in psychology from Colorado State University in 1970. She also completed The Art of the Cold Call. Nancy is a member of Toastmasters International.",
"ReportsTo": 2,
"PhotoPath": "http://accweb/emmployees/davolio.bmp",
"Subordinates": null
},
{
"EmployeeID": 3,
"LastName": "Leverling",
"FirstName": "Janet",
"Title": "Sales Representative",
"TitleOfCourtesy": "Ms.",
"BirthDate": "1963-08-30T00:00:00Z",
"HireDate": "1992-04-01T00:00:00Z",
"Address": "722 Moss Bay Blvd.",
"City": "Kirkland",
"Region": "WA",
"PostalCode": "98033",
"Country": "USA",
"HomePhone": "(206) 555-3412",
"Extension": "3355",
"Photo": "",
"Notes": "Janet has a BS degree in chemistry from Boston College (1984). She has also completed a certificate program in food retailing management. Janet was hired as a sales associate in 1991 and promoted to sales representative in February 1992.",
"ReportsTo": 2,
"PhotoPath": "http://accweb/emmployees/leverling.bmp",
"Subordinates": null
},
{
"EmployeeID": 4,
"LastName": "Peacock",
"FirstName": "Margaret",
"Title": "Sales Representative",
"TitleOfCourtesy": "Mrs.",
"BirthDate": "1937-09-19T00:00:00Z",
"HireDate": "1993-05-03T00:00:00Z",
"Address": "4110 Old Redmond Rd.",
"City": "Redmond",
"Region": "WA",
"PostalCode": "98052",
"Country": "USA",
"HomePhone": "(206) 555-8122",
"Extension": "5176",
"Photo": "",
"Notes": "Margaret holds a BA in English literature from Concordia College (1958) and an MA from the American Institute of Culinary Arts (1966). She was assigned to the London office temporarily from July through November 1992.",
"ReportsTo": 2,
"PhotoPath": "http://accweb/emmployees/peacock.bmp",
"Subordinates": null
},
{
"EmployeeID": 5,
"LastName": "Buchanan",
"FirstName": "Steven",
"Title": "Sales Manager",
"TitleOfCourtesy": "Mr.",
"BirthDate": "1955-03-04T00:00:00Z",
"HireDate": "1993-10-17T00:00:00Z",
"Address": "14 Garrett Hill",
"City": "London",
"Region": null,
"PostalCode": "SW1 8JR",
"Country": "UK",
"HomePhone": "(71) 555-4848",
"Extension": "3453",
"Photo": "",
"Notes": "Steven Buchanan graduated from St. Andrews University, Scotland, with a BSC degree in 1976. Upon joining the company as a sales representative in 1992, he spent 6 months in an orientation program at the Seattle office and then returned to his permanent post in London. He was promoted to sales manager in March 1993. Mr. Buchanan has completed the courses Successful Telemarketing and International Sales Management. He is fluent in French.",
"ReportsTo": 2,
"PhotoPath": "http://accweb/emmployees/buchanan.bmp",
"Subordinates": [
{
"EmployeeID": 6,
"LastName": "Suyama",
"FirstName": "Michael",
"Title": "Sales Representative",
"TitleOfCourtesy": "Mr.",
"BirthDate": "1963-07-02T00:00:00Z",
"HireDate": "1993-10-17T00:00:00Z",
"Address": "Coventry House\\nMiner Rd.",
"City": "London",
"Region": null,
"PostalCode": "EC2 7JR",
"Country": "UK",
"HomePhone": "(71) 555-7773",
"Extension": "428",
"Photo": "",
"Notes": "Michael is a graduate of Sussex University (MA, economics, 1983) and the University of California at Los Angeles (MBA, marketing, 1986). He has also taken the courses Multi-Cultural Selling and Time Management for the Sales Professional. He is fluent in Japanese and can read and write French, Portuguese, and Spanish.",
"ReportsTo": 5,
"PhotoPath": "http://accweb/emmployees/davolio.bmp",
"Subordinates": null
},
{
"EmployeeID": 7,
"LastName": "King",
"FirstName": "Robert",
"Title": "Sales Representative",
"TitleOfCourtesy": "Mr.",
"BirthDate": "1960-05-29T00:00:00Z",
"HireDate": "1994-01-02T00:00:00Z",
"Address": "Edgeham Hollow\\nWinchester Way",
"City": "London",
"Region": null,
"PostalCode": "RG1 9SP",
"Country": "UK",
"HomePhone": "(71) 555-5598",
"Extension": "465",
"Photo": "",
"Notes": "Robert King served in the Peace Corps and traveled extensively before completing his degree in English at the University of Michigan in 1992, the year he joined the company. After completing a course entitled Selling in Europe, he was transferred to the London office in March 1993.",
"ReportsTo": 5,
"PhotoPath": "http://accweb/emmployees/davolio.bmp",
"Subordinates": null
},
{
"EmployeeID": 9,
"LastName": "Dodsworth",
"FirstName": "Anne",
"Title": "Sales Representative",
"TitleOfCourtesy": "Ms.",
"BirthDate": "1966-01-27T00:00:00Z",
"HireDate": "1994-11-15T00:00:00Z",
"Address": "7 Houndstooth Rd.",
"City": "London",
"Region": null,
"PostalCode": "WG2 7LT",
"Country": "UK",
"HomePhone": "(71) 555-4444",
"Extension": "452",
"Photo": "",
"Notes": "Anne has a BA degree in English from St. Lawrence College. She is fluent in French and German.",
"ReportsTo": 5,
"PhotoPath": "http://accweb/emmployees/davolio.bmp",
"Subordinates": null
}
]
},
{
"EmployeeID": 8,
"LastName": "Callahan",
"FirstName": "Laura",
"Title": "Inside Sales Coordinator",
"TitleOfCourtesy": "Ms.",
"BirthDate": "1958-01-09T00:00:00Z",
"HireDate": "1994-03-05T00:00:00Z",
"Address": "4726 - 11th Ave. N.E.",
"City": "Seattle",
"Region": "WA",
"PostalCode": "98105",
"Country": "USA",
"HomePhone": "(206) 555-1189",
"Extension": "2344",
"Photo": "",
"Notes": "Laura received a BA in psychology from the University of Washington. She has also completed a course in business French. She reads and writes French.",
"ReportsTo": 2,
"PhotoPath": "http://accweb/emmployees/davolio.bmp",
"Subordinates": null
}
]
}
`)
}
var suppliersWithFax = CTE("suppliers_fax").AS(
SELECT(
Suppliers.SupplierID,
Suppliers.ContactName,
Suppliers.Country,
).FROM(
Suppliers,
).WHERE(Suppliers.Fax.IS_NOT_NULL()),
)
func SuppliersNotFromUSorAUS(suppliersCTE CommonTableExpression) CommonTableExpression {
return CTE("not_from_us_or_aus").AS(
SELECT(
suppliersCTE.AllColumns(),
).FROM(
suppliersCTE,
).WHERE(
Suppliers.Country.From(suppliersCTE).NOT_IN(String("US"), String("Australia")),
),
)
}
func TestCTEReuse(t *testing.T) {
suppliersFilteredByCountry := SuppliersNotFromUSorAUS(suppliersWithFax)
supplierContactName := Suppliers.ContactName.From(suppliersFilteredByCountry)
stmt := WITH(
suppliersWithFax,
suppliersFilteredByCountry,
)(
SELECT(
suppliersFilteredByCountry.AllColumns(),
).FROM(
suppliersFilteredByCountry,
).WHERE(
supplierContactName.NOT_EQ(String("John")),
),
)
// fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, `
WITH suppliers_fax AS (
SELECT suppliers.supplier_id AS "suppliers.supplier_id",
suppliers.contact_name AS "suppliers.contact_name",
suppliers.country AS "suppliers.country"
FROM northwind.suppliers
WHERE suppliers.fax IS NOT NULL
),not_from_us_or_aus AS (
SELECT suppliers_fax."suppliers.supplier_id" AS "suppliers.supplier_id",
suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name",
suppliers_fax."suppliers.country" AS "suppliers.country"
FROM suppliers_fax
WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia')
)
SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id",
not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name",
not_from_us_or_aus."suppliers.country" AS "suppliers.country"
FROM not_from_us_or_aus
WHERE not_from_us_or_aus."suppliers.contact_name" != 'John';
`)
var dest []model.Suppliers
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 11)
}
func TestWitStatement_CTE_NotMaterialized(t *testing.T) {
orders1 := CTE("orders1")
orders1ID := Orders.OrderID.From(orders1)
orders2 := orders1.ALIAS("orders2")
orders2ID := Orders.OrderID.From(orders2)
stmt := WITH(
orders1.AS_NOT_MATERIALIZED(
SELECT(
Orders.OrderID,
Orders.EmployeeID,
Orders.ShipCity,
).FROM(
Orders,
),
),
)(
SELECT(
orders1.AllColumns().As("orders1.*"),
orders2.AllColumns().As("orders2.*"),
).FROM(
orders1.
INNER_JOIN(orders2, orders1ID.EQ(orders2ID)),
).WHERE(
orders1ID.LT(Int(10320)),
),
)
// fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, `
WITH orders1 AS NOT MATERIALIZED (
SELECT orders.order_id AS "orders.order_id",
orders.employee_id AS "orders.employee_id",
orders.ship_city AS "orders.ship_city"
FROM northwind.orders
)
SELECT orders1."orders.order_id" AS "orders1.order_id",
orders1."orders.employee_id" AS "orders1.employee_id",
orders1."orders.ship_city" AS "orders1.ship_city",
orders2."orders.order_id" AS "orders2.order_id",
orders2."orders.employee_id" AS "orders2.employee_id",
orders2."orders.ship_city" AS "orders2.ship_city"
FROM orders1
INNER JOIN orders1 AS orders2 ON (orders1."orders.order_id" = orders2."orders.order_id")
WHERE orders1."orders.order_id" < $1;
`)
var dest []struct {
Orders1 model.Orders `alias:"orders1.*"`
Orders2 model.Orders `alias:"orders2.*"`
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 72)
fmt.Println(len(dest))
}

View file

@ -347,10 +347,13 @@ func TestFloatOperators(t *testing.T) {
AllTypes.Numeric.IS_NOT_DISTINCT_FROM(AllTypes.Numeric).AS("not_distinct1"),
AllTypes.Decimal.IS_NOT_DISTINCT_FROM(Float(12)).AS("not_distinct2"),
AllTypes.Real.IS_NOT_DISTINCT_FROM(Float(12.12)).AS("not_distinct3"),
AllTypes.Numeric.LT(Float(124)).AS("lt1"),
AllTypes.Numeric.LT(Float(34.56)).AS("lt2"),
AllTypes.Numeric.GT(Float(124)).AS("gt1"),
AllTypes.Numeric.GT(Float(34.56)).AS("gt2"),
AllTypes.Numeric.BETWEEN(Float(1.34), AllTypes.Decimal).AS("between"),
AllTypes.Numeric.NOT_BETWEEN(AllTypes.Decimal.MUL(Float(3)), Float(100.12)).AS("not_between"),
AllTypes.Decimal.ADD(AllTypes.Decimal).AS("add1"),
AllTypes.Decimal.ADD(Float(11.22)).AS("add2"),
@ -395,6 +398,8 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.numeric < ?) AS "lt2",
(all_types.numeric > ?) AS "gt1",
(all_types.numeric > ?) AS "gt2",
(all_types.numeric BETWEEN ? AND all_types.decimal) AS "between",
(all_types.numeric NOT BETWEEN (all_types.decimal * ?) AND ?) AS "not_between",
(all_types.decimal + all_types.decimal) AS "add1",
(all_types.decimal + ?) AS "add2",
(all_types.decimal - all_types.decimal_ptr) AS "sub1",
@ -441,40 +446,32 @@ func TestIntegerOperators(t *testing.T) {
AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"),
AllTypes.BigInt.EQ(Int(12)).AS("eq2"),
AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"),
AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"),
AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"),
AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"),
AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"),
AllTypes.BigInt.LT(Int(65)).AS("lt2"),
AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"),
AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"),
AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"),
AllTypes.BigInt.GT(Int(65)).AS("gt2"),
AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"),
AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"),
AllTypes.Integer.BETWEEN(Int(11), Int(200)).AS("between"),
AllTypes.Integer.NOT_BETWEEN(Int(66), Int(77)).AS("not_between"),
AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"),
AllTypes.BigInt.ADD(Int(11)).AS("add2"),
AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"),
AllTypes.BigInt.SUB(Int(11)).AS("sub2"),
AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"),
AllTypes.BigInt.MUL(Int(11)).AS("mul2"),
AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"),
AllTypes.BigInt.DIV(Int(11)).AS("div2"),
AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"),
AllTypes.BigInt.MOD(Int(11)).AS("mod2"),
@ -483,19 +480,15 @@ func TestIntegerOperators(t *testing.T) {
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"),
AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"),
AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"),
AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"),
AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"),
BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"),
BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"),
@ -522,7 +515,8 @@ func TestIntegerOperators(t *testing.T) {
require.Equal(t, *dest[0].BitXor2, int64(5))
require.Equal(t, *dest[0].BitShiftLeft1, int64(1792))
require.Equal(t, *dest[0].BitShiftRight2, int64(7))
require.Equal(t, *dest[0].Between, false)
require.Equal(t, *dest[0].NotBetween, true)
}
func TestStringOperators(t *testing.T) {
@ -540,6 +534,8 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.LT(String("Text")),
AllTypes.Text.LT_EQ(AllTypes.VarCharPtr),
AllTypes.Text.LT_EQ(String("Text")),
AllTypes.Text.BETWEEN(String("min"), String("max")),
AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr),
AllTypes.Text.CONCAT(String("text2")),
AllTypes.Text.CONCAT(Int(11)),
AllTypes.Text.LIKE(String("abc")),
@ -717,27 +713,23 @@ func TestDateExpressions(t *testing.T) {
AllTypes.Date.EQ(AllTypes.Date),
AllTypes.Date.EQ(Date(2019, 6, 6)),
AllTypes.DatePtr.NOT_EQ(AllTypes.Date),
AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)),
AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date).AS("distinct1"),
AllTypes.Date.IS_DISTINCT_FROM(Date(2008, 7, 4)).AS("distinct2"),
AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date),
AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)),
AllTypes.Date.LT(AllTypes.Date),
AllTypes.Date.LT(Date(2019, 4, 6)),
AllTypes.Date.LT_EQ(AllTypes.Date),
AllTypes.Date.LT_EQ(Date(2019, 5, 5)),
AllTypes.Date.GT(AllTypes.Date),
AllTypes.Date.GT(Date(2019, 1, 4)),
AllTypes.Date.GT_EQ(AllTypes.Date),
AllTypes.Date.GT_EQ(Date(2019, 2, 3)),
AllTypes.Date.BETWEEN(Date(2000, 2, 2), AllTypes.DatePtr),
AllTypes.Date.NOT_BETWEEN(AllTypes.DatePtr, Date(2000, 2, 2)),
//AllTypes.Date.ADD(INTERVAL2(2, HOUR)),
//AllTypes.Date.ADD(INTERVAL2(1, DAY, 7, MONTH)),
@ -790,12 +782,12 @@ func TestTimeExpressions(t *testing.T) {
AllTypes.TimePtr.NOT_EQ(AllTypes.Time),
AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)),
AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)),
AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)),
AllTypes.Time.BETWEEN(Time(11, 0, 30, 100), AllTypes.TimePtr),
AllTypes.Time.NOT_BETWEEN(AllTypes.TimePtr, TIME(time.Now())),
AllTypes.Time.LT(AllTypes.Time),
AllTypes.Time.LT(Time(17, 46, 6)),
@ -822,6 +814,8 @@ func TestTimeExpressions(t *testing.T) {
CURRENT_TIME(),
)
//fmt.Println(query.DebugSql())
var dest struct {
Time1 string
Time2 time.Time
@ -855,27 +849,23 @@ func TestDateTimeExpressions(t *testing.T) {
AllTypes.DateTime.EQ(AllTypes.DateTime),
AllTypes.DateTime.EQ(dateTime),
AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime),
AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)),
AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime),
AllTypes.DateTime.IS_DISTINCT_FROM(dateTime),
AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime),
AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime),
AllTypes.DateTime.LT(AllTypes.DateTime),
AllTypes.DateTime.LT(dateTime),
AllTypes.DateTime.LT_EQ(AllTypes.DateTime),
AllTypes.DateTime.LT_EQ(dateTime),
AllTypes.DateTime.GT(AllTypes.DateTime),
AllTypes.DateTime.GT(dateTime),
AllTypes.DateTime.GT_EQ(AllTypes.DateTime),
AllTypes.DateTime.GT_EQ(dateTime),
AllTypes.DateTime.BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr),
AllTypes.DateTime.NOT_BETWEEN(AllTypes.DateTimePtr, AllTypes.TimestampPtr),
//AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
//AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)),

Some files were not shown because too many files have changed in this diff Show more