Merge pull request #10 from go-jet/develop

Release 2.0.0: MySQL and MariaDB support.
This commit is contained in:
go-jet 2019-08-20 10:09:43 +02:00 committed by GitHub
commit 361e3605f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
209 changed files with 12509 additions and 467897 deletions

View file

@ -3,17 +3,25 @@
# Check https://circleci.com/docs/2.0/language-go/ for more details # Check https://circleci.com/docs/2.0/language-go/ for more details
version: 2 version: 2
jobs: jobs:
build: build-postgres-and-mysql:
docker: docker:
# specify the version # specify the version
- image: circleci/golang:1.11 - image: circleci/golang:1.11
- image: circleci/postgres:10.6-alpine - image: circleci/postgres:10.8-alpine
environment: # environment variables for primary container environment: # environment variables for primary container
POSTGRES_USER: jet POSTGRES_USER: jet
POSTGRES_PASSWORD: jet POSTGRES_PASSWORD: jet
POSTGRES_DB: jetdb POSTGRES_DB: jetdb
- image: circleci/mysql:8.0
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself environment: # environment variables for the build itself
@ -22,12 +30,20 @@ jobs:
steps: steps:
- checkout - checkout
# specify any bash command here prefixed with `run: ` - run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
go get github.com/google/uuid go get github.com/google/uuid
go get github.com/lib/pq go get github.com/lib/pq
go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile go get github.com/pkg/profile
go get gotest.tools/assert go get gotest.tools/assert
@ -48,14 +64,37 @@ jobs:
echo Failed waiting for Postgres && exit 1 echo Failed waiting for Postgres && exit 1
- run: - run:
name: Init Postgres database name: Waiting for MySQL to be ready
command: | command: |
cd tests for i in `seq 1 10`;
go run ./init/init.go do
cd .. nc -z 127.0.0.1 3306 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run:
name: Install MySQL CLI;
command: |
sudo apt-get install default-mysql-client
- run:
name: Create MySQL user and databases
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
- run:
name: Init Postgres database
command: |
cd tests
go run ./init/init.go -testsuite all
cd ..
- run: mkdir -p $TEST_RESULTS - run: mkdir -p $TEST_RESULTS
- run: go test -v . ./tests -coverpkg=github.com/go-jet/jet,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run: - run:
name: Upload code coverage name: Upload code coverage
@ -67,4 +106,75 @@ jobs:
- store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/ - store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/
path: /tmp/test-results path: /tmp/test-results
build-mariadb:
docker:
# specify the version
- image: circleci/golang:1.11
- image: circleci/mariadb:10.3
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself
TEST_RESULTS: /tmp/test-results # path to where test results will be saved
steps:
- checkout
- run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run:
name: Install dependencies
command: |
go get github.com/google/uuid
go get github.com/lib/pq
go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile
go get gotest.tools/assert
go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report
go install github.com/go-jet/jet/cmd/jet
- run:
name: Install MySQL CLI;
command: |
sudo apt-get install default-mysql-client
- run:
name: Init MariaDB database
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
- run:
name: Init MariaDB database
command: |
cd tests
go run ./init/init.go -testsuite MariaDB
cd ..
- run:
name: Run MariaDB tests
command: |
go test -v ./tests/mysql/ -source=MariaDB
workflows:
version: 2
build_and_test:
jobs:
- build-postgres-and-mysql
- build-mariadb

1
.gitignore vendored
View file

@ -18,3 +18,4 @@
# Test files # Test files
gen gen
.gentestdata .gentestdata
.tests/testdata/

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "tests/testdata"]
path = tests/testdata
url = https://github.com/go-jet/jet-test-data

View file

@ -1,16 +1,19 @@
# Jet # Jet
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop) [![CircleCI](https://circleci.com/gh/go-jet/jet/tree/master.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet) [![codecov](https://codecov.io/gh/go-jet/jet/branch/master/graph/badge.svg)](https://codecov.io/gh/go-jet/jet)
[![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet) [![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet)
[![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet) [![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet)
[![GitHub release](https://img.shields.io/github/release/go-jet/jet.svg)](https://github.com/go-jet/jet/releases) [![GitHub release](https://img.shields.io/github/release/go-jet/jet.svg)](https://github.com/go-jet/jet/releases)
Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily Jet is a framework for writing type-safe SQL queries in Go, with ability to easily
convert database query result to desired arbitrary structure. convert database query result into desired arbitrary object structure.
_*Support for additional databases will be added in future jet releases._ Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases.
![jet](https://github.com/go-jet/jet/wiki/image/jet.png)
Jet is the easiest and fastest way to write complex SQL queries and map database query result
into complex object composition. __It is not an ORM.__
## Contents ## Contents
- [Features](#features) - [Features](#features)
@ -28,14 +31,18 @@ _*Support for additional databases will be added in future jet releases._
## Features ## Features
1) Auto-generated type-safe SQL Builder 1) Auto-generated type-safe SQL Builder
- Types - boolean, integers(smallint, integer, bigint), floats(real, numeric, decimal, double precision), - PostgreSQL:
strings(text, character, character varying), date, time(z), timestamp(z) and enums. * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, sub-queries)`
- Statements: * INSERT `(VALUES, query, RETURNING)`,
* SELECT (DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, sub-queries) * UPDATE `(SET, WHERE, RETURNING)`,
* INSERT (VALUES, query, RETURNING), * DELETE `(WHERE, RETURNING)`,
* UPDATE (SET, WHERE, RETURNING), * LOCK `(IN, NOWAIT)`
* DELETE (WHERE, RETURNING), - MySQL and MariaDB:
* LOCK (IN, NOWAIT) * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, sub-queries)`
* INSERT `(VALUES, query)`,
* UPDATE `(SET, WHERE)`,
* DELETE `(WHERE, ORDER_BY, LIMIT)`,
* LOCK `(READ, WRITE)`
2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store 2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store
result of database queries. Can be combined to create desired query result destination. result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure. 3) Query execution with result mapping to arbitrary destination structure.
@ -64,16 +71,16 @@ go install github.com/go-jet/jet/cmd/jet
Make sure GOPATH bin folder is added to the PATH environment variable. Make sure GOPATH bin folder is added to the PATH environment variable.
### Quick Start ### Quick Start
For this quick start example we will use sample _dvd rental_ database. Full database dump can be found in [./tests/init/data/dvds.sql](./tests/init/data/dvds.sql). For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in [./tests/testdata/init/postgres/dvds.sql](./tests/testdata/init/postgres/dvds.sql).
Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png). Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png).
#### Generate SQL Builder and Model files #### Generate SQL Builder and Model files
To generate jet SQL Builder and Data Model files from postgres database we need to call `jet` generator with postgres To generate jet SQL Builder and Data Model files from postgres database, we need to call `jet` generator with postgres
connection parameters and root destination folder path for generated files.\ connection parameters and root destination folder path for generated files.\
Assuming we are running local postgres database, with user `jetuser`, user password `jetpass`, database `jetdb` and Assuming we are running local postgres database, with user `jetuser`, user password `jetpass`, database `jetdb` and
schema `dvds` we will use this command: schema `dvds` we will use this command:
```sh ```sh
jet -host=localhost -port=5432 -user=jetuser -password=jetpass -dbname=jetdb -schema=dvds -path=./gen jet -source=PostgreSQL -host=localhost -port=5432 -user=jetuser -password=jetpass -dbname=jetdb -schema=dvds -path=./gen
``` ```
```sh ```sh
Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable
@ -87,7 +94,9 @@ Generating enum sql builder files...
Generating enum model files... Generating enum model files...
Done Done
``` ```
_*User has to have a permission to read information schema tables_ Procedure is similar for MySQL or MariaDB, except source should be replaced with `MySql` or `MariaDB` and schema name should
be omitted (both databases doesn't have schema support).
_*User has to have a permission to read information schema tables._
As command output suggest, Jet will: As command output suggest, Jet will:
- connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema - connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema
@ -116,15 +125,17 @@ Generated files folder structure will look like this:
Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store
results of the SQL queries. results of the SQL queries.
#### Lets write some SQL queries in Go #### Lets write some SQL queries in Go
First we need to import jet and generated files from previous step: First we need to import jet and generated files from previous step:
```go ```go
import ( import (
// dot import so that Go code would resemble as much as native SQL // dot import so go code would resemble as much as native SQL
// dot import is not mandatory // dot import is not mandatory
. "github.com/go-jet/jet" . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
. "github.com/go-jet/jet/examples/quick-start/gen/jetdb/dvds/table" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/examples/quick-start/gen/jetdb/dvds/model" "github.com/go-jet/jet/examples/quick-start/gen/jetdb/dvds/model"
) )
@ -153,13 +164,13 @@ stmt := SELECT(
Film.FilmID.ASC(), Film.FilmID.ASC(),
) )
``` ```
With package(dot) import above statements looks almost the same as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with Package(dot) import is used so that statement would resemble as much as possible as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns
and can be compared only with integer columns and expressions. and can be compared only with integer columns and expressions.
__How to get parametrized SQL query from statement?__ __How to get parametrized SQL query from statement?__
```go ```go
query, args, err := stmt.Sql() query, args := stmt.Sql()
``` ```
query - parametrized query\ query - parametrized query\
args - parameters for the query args - parameters for the query
@ -209,9 +220,9 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
__How to get debug SQL from statement?__ __How to get debug SQL from statement?__
```go ```go
debugSql, err := stmt.DebugSql() debugSql := stmt.DebugSql()
``` ```
debugSql - query string that can be copy pasted to sql editor and executed. It's not intended to be used in production. debugSql - query string that can be copy pasted to sql editor and executed. __It's not intended to be used in production!!!__
<details> <details>
<summary>Click to see debug sql</summary> <summary>Click to see debug sql</summary>
@ -275,9 +286,9 @@ var dest []struct {
``` ```
Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language
`Langauge` field is just a single model struct. `Langauge` field is just a single model struct.
_*There is no limitation of how big or nested destination structure can be._ _*There is no limitation of how big or nested destination can be._
Now lets execute a above statement on open database connection db and store result into `dest`. Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`.
```go ```go
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
@ -485,12 +496,12 @@ found at project [wiki](https://github.com/go-jet/jet/wiki) page.
## Benefits ## Benefits
What are the benefits of writing SQL in Go using Jet? The biggest benefit is speed. What are the benefits of writing SQL in Go using Jet?
Speed is improved in 3 major areas: The biggest benefit is speed. Speed is improved in 3 major areas:
##### Speed of development ##### Speed of development
Writing SQL queries is much easier directly from Go, because programmer has the help of SQL code completion and SQL type safety directly in Go. Writing SQL queries is much easier, because programmer has the help of SQL code completion and SQL type safety directly in Go.
Writing code is much faster and code is more robust. Automatic scan to arbitrary structure removes a lot of headache and Writing code is much faster and code is more robust. Automatic scan to arbitrary structure removes a lot of headache and
boilerplate code needed to structure database query result. boilerplate code needed to structure database query result.
@ -508,7 +519,7 @@ return result in one database call. Handler execution will be only proportional
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate). ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call. With Jet you can even join the whole database and store the whole structured result in in one query call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/chinook_db_test.go#L40). This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40).
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found ##### How quickly bugs are found
@ -529,14 +540,14 @@ Without Jet these bugs will have to be either caught by some test or by manual t
## Dependencies ## Dependencies
At the moment Jet dependence only of: At the moment Jet dependence only of:
- `github.com/google/uuid` _(Used for debug purposes and in data model files)_ - `github.com/lib/pq` _(Used by jet generator to read information about database schema from `PostgreSQL`)_
- `github.com/lib/pq` _(Used by Jet to read information about database schema types)_ - `github.com/go-sql-driver/mysql` _(Used by jet generator to read information about database from `MySQL` and `MariaDB`)_
- `github.com/google/uuid` _(Used in data model files and for debug purposes)_
To run the tests, additional dependencies are required: To run the tests, additional dependencies are required:
- `github.com/pkg/profile` - `github.com/pkg/profile`
- `gotest.tools/assert` - `gotest.tools/assert`
## Versioning ## Versioning
[SemVer](http://semver.org/) is used for versioning. For the versions available, see the [releases](https://github.com/go-jet/jet/releases). [SemVer](http://semver.org/) is used for versioning. For the versions available, see the [releases](https://github.com/go-jet/jet/releases).

View file

@ -1,34 +0,0 @@
package jet
type alias struct {
expression Expression
alias string
}
func newAlias(expression Expression, aliasName string) projection {
return &alias{
expression: expression,
alias: aliasName,
}
}
func (a *alias) from(subQuery SelectTable) projection {
column := newColumn(a.alias, "", nil)
column.parent = &column
column.subQuery = subQuery
return &column
}
func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder) error {
err := a.expression.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeQuotedString(a.alias)
return nil
}

133
cast.go
View file

@ -1,133 +0,0 @@
package jet
import "fmt"
type cast interface {
// Cast expression AS bool type
AS_BOOL() BoolExpression
// Cast expression AS smallint type
AS_SMALLINT() IntegerExpression
// Cast expression AS integer type
AS_INTEGER() IntegerExpression
// Cast expression AS bigint type
AS_BIGINT() IntegerExpression
// Cast expression AS numeric type, using precision and optionally scale
AS_NUMERIC(precision int, scale ...int) FloatExpression
// Cast expression AS real type
AS_REAL() FloatExpression
// Cast expression AS double precision type
AS_DOUBLE() FloatExpression
// Cast expression AS text type
AS_TEXT() StringExpression
// Cast expression AS date type
AS_DATE() DateExpression
// Cast expression AS time type
AS_TIME() TimeExpression
// Cast expression AS time with time timezone type
AS_TIMEZ() TimezExpression
// Cast expression AS timestamp type
AS_TIMESTAMP() TimestampExpression
// Cast expression AS timestamp with timezone type
AS_TIMESTAMPZ() TimestampzExpression
}
type castImpl struct {
Expression
castType string
}
// CAST wraps expression for casting.
// For instance: CAST(table.column).AS_BOOL()
func CAST(expression Expression) cast {
return &castImpl{
Expression: expression,
}
}
func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
err := b.Expression.serialize(statement, out, options...)
out.writeString("::" + b.castType)
return err
}
func (b *castImpl) AS_BOOL() BoolExpression {
b.castType = "boolean"
return BoolExp(b)
}
func (b *castImpl) AS_SMALLINT() IntegerExpression {
b.castType = "smallint"
return IntExp(b)
}
// Cast expression AS integer type
func (b *castImpl) AS_INTEGER() IntegerExpression {
b.castType = "integer"
return IntExp(b)
}
// Cast expression AS bigint type
func (b *castImpl) AS_BIGINT() IntegerExpression {
b.castType = "bigint"
return IntExp(b)
}
// Cast expression AS numeric type, using precision and optionally scale
func (b *castImpl) AS_NUMERIC(precision int, scale ...int) FloatExpression {
if len(scale) > 0 {
b.castType = fmt.Sprintf("numeric(%d, %d)", precision, scale[0])
} else {
b.castType = fmt.Sprintf("numeric(%d)", precision)
}
return FloatExp(b)
}
// Cast expression AS real type
func (b *castImpl) AS_REAL() FloatExpression {
b.castType = "real"
return FloatExp(b)
}
// Cast expression AS double precision type
func (b *castImpl) AS_DOUBLE() FloatExpression {
b.castType = "double precision"
return FloatExp(b)
}
// Cast expression AS text type
func (b *castImpl) AS_TEXT() StringExpression {
b.castType = "text"
return StringExp(b)
}
// Cast expression AS date type
func (b *castImpl) AS_DATE() DateExpression {
b.castType = "date"
return DateExp(b)
}
// Cast expression AS time type
func (b *castImpl) AS_TIME() TimeExpression {
b.castType = "time without time zone"
return TimeExp(b)
}
// Cast expression AS time with time timezone type
func (b *castImpl) AS_TIMEZ() TimezExpression {
b.castType = "time with time zone"
return TimezExp(b)
}
// Cast expression AS timestamp type
func (b *castImpl) AS_TIMESTAMP() TimestampExpression {
b.castType = "timestamp without time zone"
return TimestampExp(b)
}
// Cast expression AS timestamp with timezone type
func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression {
b.castType = "timestamp with time zone"
return TimestampzExp(b)
}

255
clause.go
View file

@ -1,255 +0,0 @@
package jet
import (
"bytes"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
)
type serializeOption int
const (
noWrap serializeOption = iota
)
type clause interface {
serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error
}
func contains(options []serializeOption, option serializeOption) bool {
for _, opt := range options {
if opt == option {
return true
}
}
return false
}
type sqlBuilder struct {
buff bytes.Buffer
args []interface{}
lastChar byte
ident int
}
type statementType string
const (
selectStatement statementType = "SELECT"
insertStatement statementType = "INSERT"
updateStatement statementType = "UPDATE"
deleteStatement statementType = "DELETE"
setStatement statementType = "SET"
lockStatement statementType = "LOCK"
)
const defaultIdent = 5
func (q *sqlBuilder) increaseIdent() {
q.ident += defaultIdent
}
func (q *sqlBuilder) decreaseIdent() {
if q.ident < defaultIdent {
q.ident = 0
}
q.ident -= defaultIdent
}
func (q *sqlBuilder) writeProjections(statement statementType, projections []projection) error {
q.increaseIdent()
err := serializeProjectionList(statement, projections, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeFrom(statement statementType, table ReadableTable) error {
q.newLine()
q.writeString("FROM")
q.increaseIdent()
err := table.serialize(statement, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeWhere(statement statementType, where Expression) error {
q.newLine()
q.writeString("WHERE")
q.increaseIdent()
err := where.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.newLine()
q.writeString("GROUP BY")
q.increaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.newLine()
q.writeString("ORDER BY")
q.increaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeHaving(statement statementType, having Expression) error {
q.newLine()
q.writeString("HAVING")
q.increaseIdent()
err := having.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeReturning(statement statementType, returning []projection) error {
if len(returning) == 0 {
return nil
}
q.newLine()
q.writeString("RETURNING")
q.increaseIdent()
return q.writeProjections(statement, returning)
}
func (q *sqlBuilder) newLine() {
q.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident))
}
func (q *sqlBuilder) write(data []byte) {
if len(data) == 0 {
return
}
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.buff.Len() > 0 {
q.buff.WriteByte(' ')
}
q.buff.Write(data)
q.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
func (q *sqlBuilder) writeQuotedString(str string) {
q.writeString(`"` + str + `"`)
}
func (q *sqlBuilder) writeString(str string) {
q.write([]byte(str))
}
func (q *sqlBuilder) writeIdentifier(name string) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap {
q.writeString(`"` + name + `"`)
} else {
q.writeString(name)
}
}
func (q *sqlBuilder) writeByte(b byte) {
q.write([]byte{b})
}
func (q *sqlBuilder) finalize() (string, []interface{}) {
return q.buff.String() + ";\n", q.args
}
func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
q.writeString(argToString(arg))
}
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.writeString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

View file

@ -3,12 +3,19 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/generator/postgres" mysqlgen "github.com/go-jet/jet/generator/mysql"
postgresgen "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/postgres"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"os" "os"
"strings"
) )
var ( var (
source string
host string host string
port int port int
user string user string
@ -22,14 +29,16 @@ var (
) )
func init() { func init() {
flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL or MariaDB)")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port") flag.IntVar(&port, "port", 0, "Database port")
flag.StringVar(&user, "user", "", "Database user") flag.StringVar(&user, "user", "", "Database user")
flag.StringVar(&password, "password", "", "The users password") flag.StringVar(&password, "password", "", "The users password")
flag.StringVar(&sslmode, "sslmode", "disable", "Whether or not to use SSL(optional)")
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional)") flag.StringVar(&params, "params", "", "Additional connection string parameters(optional)")
flag.StringVar(&dbName, "dbname", "", "name of the database") flag.StringVar(&dbName, "dbname", "", "Database name")
flag.StringVar(&schemaName, "schema", "public", "Database schema name.") flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public") (ignored for MySQL and MariaDB)`)
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL(optional)(default "disable") (ignored for MySQL and MariaDB)`)
flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") flag.StringVar(&destDir, "path", "", "Destination dir for files generated.")
} }
@ -38,7 +47,11 @@ func main() {
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, ` _, _ = fmt.Fprint(os.Stdout, `
Usage of jet: Jet generator 2.0.0
Usage:
-source string
Database system name (PostgreSQL, MySQL or MariaDB)
-host string -host string
Database host path (Example: localhost) Database host path (Example: localhost)
-port int -port int
@ -48,13 +61,13 @@ Usage of jet:
-password string -password string
The users password The users password
-dbname string -dbname string
name of the database Database name
-params string -params string
Additional connection string parameters(optional) Additional connection string parameters(optional)
-schema string -schema string
Database schema name. (default "public") Database schema name. (default "public") (ignored for MySQL and MariaDB)
-sslmode string -sslmode string
Whether or not to use SSL(optional) (default "disable") Whether or not to use SSL(optional) (default "disable") (ignored for MySQL and MariaDB)
-path string -path string
Destination dir for files generated. Destination dir for files generated.
`) `)
@ -62,28 +75,54 @@ Usage of jet:
flag.Parse() flag.Parse()
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" { if source == "" || host == "" || port == 0 || user == "" || dbName == "" {
fmt.Println("\njet: required flag missing") printErrorAndExit("\nERROR: required flag(s) missing")
flag.Usage()
os.Exit(-2)
} }
genData := postgres.DBConnection{ var err error
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName, switch strings.ToLower(strings.TrimSpace(source)) {
SchemaName: schemaName, case strings.ToLower(postgres.Dialect.Name()),
strings.ToLower(postgres.Dialect.PackageName()):
genData := postgresgen.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName,
SchemaName: schemaName,
}
err = postgresgen.Generate(destDir, genData)
case strings.ToLower(mysql.Dialect.Name()), "mariadb":
dbConn := mysqlgen.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
Params: params,
DBName: dbName,
}
err = mysqlgen.Generate(destDir, dbConn)
default:
fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.")
os.Exit(-4)
} }
err := postgres.Generate(destDir, genData)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(-1) os.Exit(-5)
} }
} }
func printErrorAndExit(error string) {
fmt.Println(error)
flag.Usage()
os.Exit(-2)
}

145
column.go
View file

@ -1,145 +0,0 @@
// Modeling of columns
package jet
type column interface {
Name() string
TableName() string
setTableName(table string)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// Column is common column interface for all types of columns.
type Column interface {
Expression
column
}
// The base type for real materialized columns.
type columnImpl struct {
expressionInterfaceImpl
name string
tableName string
subQuery SelectTable
}
func newColumn(name string, tableName string, parent Column) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
}
bc.expressionInterfaceImpl.parent = parent
return bc
}
func (c *columnImpl) Name() string {
return c.name
}
func (c *columnImpl) TableName() string {
return c.tableName
}
func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string {
if c.tableName != "" {
return c.tableName + "." + c.name
}
return c.name
}
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if statement == setStatement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.writeString(`"` + c.defaultAlias() + `"`) //always quote
return nil
}
return c.serialize(statement, out)
}
func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuilder) error {
err := c.serialize(statement, out)
if err != nil {
return err
}
out.writeString(`AS "` + c.defaultAlias() + `"`)
return nil
}
func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c.subQuery != nil {
out.writeIdentifier(c.subQuery.Alias())
out.writeByte('.')
out.writeQuotedString(c.defaultAlias())
} else {
if c.tableName != "" {
out.writeIdentifier(c.tableName)
out.writeByte('.')
}
out.writeIdentifier(c.name)
}
return nil
}
//------------------------------------------------------//
// ColumnList is redefined type to support list of columns as single projection
type ColumnList []Column
// projection interface implementation
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.from(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serializeForProjection(statement statementType, out *sqlBuilder) error {
projections := columnListToProjectionList(cl)
err := serializeProjectionList(statement, projections, out)
if err != nil {
return err
}
return nil
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }

View file

@ -1,45 +0,0 @@
package jet
import "testing"
var dateVar = Date(2000, 12, 30)
func TestDateExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.EQ(table2ColDate), "(table1.col_date = table2.col_date)")
assertClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::date)", "2000-12-30")
}
func TestDateExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.NOT_EQ(table2ColDate), "(table1.col_date != table2.col_date)")
assertClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::date)", "2000-12-30")
}
func TestDateExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(table2ColDate), "(table1.col_date IS DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(table2ColDate), "(table1.col_date IS NOT DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT(table2ColDate), "(table1.col_date > table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::date)", "2000-12-30")
}
func TestDateExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT_EQ(table2ColDate), "(table1.col_date >= table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::date)", "2000-12-30")
}
func TestDateExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT(table2ColDate), "(table1.col_date < table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::date)", "2000-12-30")
}
func TestDateExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT_EQ(table2ColDate), "(table1.col_date <= table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::date)", "2000-12-30")
}

View file

@ -1,102 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
RETURNING(projections ...projection) DeleteStatement
}
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
returning []projection
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStatement {
d.returning = projections
return d
}
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
if d == nil {
return errors.New("jet: delete statement is nil")
}
out.newLine()
out.writeString("DELETE FROM")
if d.table == nil {
return errors.New("jet: nil tableName")
}
if err := d.table.serialize(deleteStatement, out); err != nil {
return err
}
if d.where == nil {
return errors.New("jet: deleting without a WHERE clause")
}
if err := out.writeWhere(deleteStatement, d.where); err != nil {
return err
}
if err := out.writeReturning(deleteStatement, d.returning); err != nil {
return err
}
return nil
}
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &sqlBuilder{}
err = d.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}
func (d *deleteStatementImpl) DebugSql() (query string, err error) {
return debugSql(d)
}
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, d, db, destination)
}
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, d, db)
}

View file

@ -1,25 +0,0 @@
package jet
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementErr(t, table1.DELETE(), `jet: deleting without a WHERE clause`)
assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: deleting without a WHERE clause`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = $1;
`, int64(1))
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), `
DELETE FROM db.table1
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`, int64(1))
}

73
doc.go
View file

@ -1,5 +1,74 @@
/* /*
Package jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability Package jet is a framework for writing type-safe SQL queries in Go, with ability to easily convert database query
to easily convert database query result to desired arbitrary structure. result into desired arbitrary object structure.
Installation
Use the bellow command to install jet
$ go get -u github.com/go-jet/jet
Install jet generator to GOPATH bin folder. This will allow generating jet files from the command line.
go install github.com/go-jet/jet/cmd/jet
*Make sure GOPATH bin folder is added to the PATH environment variable.
Usage
Jet requires already defined database schema(with tables, enums etc), so that jet generator can generate SQL Builder
and Model files. File generation is very fast, and can be added as every pre-build step.
Sample command:
jet -source=PostgreSQL -host=localhost -port=5432 -user=jet -password=pass -dbname=jetdb -schema=dvds -path=./gen
Then next step is to import generated SQL Builder and Model files and write SQL queries in Go:
import . "some_path/.gen/jetdb/dvds/table"
import "some_path/.gen/jetdb/dvds/model"
To write SQL queries for PostgreSQL import:
. "github.com/go-jet/jet/postgres"
To write SQL queries for MySQL and MariaDB import:
. "github.com/go-jet/jet/mysql"
*Dot import is used so that Go code resemble as much as native SQL. Dot import is not mandatory.
Write SQL:
// sub-query
rRatingFilms := SELECT(
Film.FilmID,
Film.Title,
Film.Rating,
).
FROM(Film).
WHERE(Film.Rating.EQ(enum.FilmRating.R)).
AsTable("rFilms")
// export column from sub-query
rFilmID := Film.FilmID.From(rRatingFilms)
// main-query
query := SELECT(
Actor.AllColumns,
FilmActor.AllColumns,
rRatingFilms.AllColumns(),
).
FROM(
rRatingFilms.
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)).
INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)
).
ORDER_BY(rFilmID, Actor.ActorID)
Store result into desired destination:
var dest []struct {
model.Film
Actors []model.Actor
}
err := query.Query(db, &dest)
Detail info about all features and use cases can be
found at project wiki page - https://github.com/go-jet/jet/wiki.
*/ */
package jet package jet

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -8,18 +8,18 @@
package enum package enum
import "github.com/go-jet/jet" import "github.com/go-jet/jet/postgres"
var MpaaRating = &struct { var MpaaRating = &struct {
G jet.StringExpression G postgres.StringExpression
Pg jet.StringExpression Pg postgres.StringExpression
Pg13 jet.StringExpression Pg13 postgres.StringExpression
R jet.StringExpression R postgres.StringExpression
Nc17 jet.StringExpression Nc17 postgres.StringExpression
}{ }{
G: jet.NewEnumValue("G"), G: postgres.NewEnumValue("G"),
Pg: jet.NewEnumValue("PG"), Pg: postgres.NewEnumValue("PG"),
Pg13: jet.NewEnumValue("PG-13"), Pg13: postgres.NewEnumValue("PG-13"),
R: jet.NewEnumValue("R"), R: postgres.NewEnumValue("R"),
Nc17: jet.NewEnumValue("NC-17"), Nc17: postgres.NewEnumValue("NC-17"),
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,22 +9,22 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Actor = newActorTable() var Actor = newActorTable()
type ActorTable struct { type ActorTable struct {
jet.Table postgres.Table
//Columns //Columns
ActorID jet.ColumnInteger ActorID postgres.ColumnInteger
FirstName jet.ColumnString FirstName postgres.ColumnString
LastName jet.ColumnString LastName postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new ActorTable with assigned alias // creates new ActorTable with assigned alias
@ -38,14 +38,14 @@ func (a *ActorTable) AS(alias string) *ActorTable {
func newActorTable() *ActorTable { func newActorTable() *ActorTable {
var ( var (
ActorIDColumn = jet.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = jet.StringColumn("first_name") FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = jet.StringColumn("last_name") LastNameColumn = postgres.StringColumn("last_name")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &ActorTable{ return &ActorTable{
Table: jet.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
@ -53,7 +53,7 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn, LastName: LastNameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, MutableColumns: postgres.ColumnList(FirstNameColumn, LastNameColumn, LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Category = newCategoryTable() var Category = newCategoryTable()
type CategoryTable struct { type CategoryTable struct {
jet.Table postgres.Table
//Columns //Columns
CategoryID jet.ColumnInteger CategoryID postgres.ColumnInteger
Name jet.ColumnString Name postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new CategoryTable with assigned alias // creates new CategoryTable with assigned alias
@ -37,20 +37,20 @@ func (a *CategoryTable) AS(alias string) *CategoryTable {
func newCategoryTable() *CategoryTable { func newCategoryTable() *CategoryTable {
var ( var (
CategoryIDColumn = jet.IntegerColumn("category_id") CategoryIDColumn = postgres.IntegerColumn("category_id")
NameColumn = jet.StringColumn("name") NameColumn = postgres.StringColumn("name")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &CategoryTable{ return &CategoryTable{
Table: jet.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn),
//Columns //Columns
CategoryID: CategoryIDColumn, CategoryID: CategoryIDColumn,
Name: NameColumn, Name: NameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(CategoryIDColumn, NameColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{NameColumn, LastUpdateColumn}, MutableColumns: postgres.ColumnList(NameColumn, LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,31 +9,31 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Film = newFilmTable() var Film = newFilmTable()
type FilmTable struct { type FilmTable struct {
jet.Table postgres.Table
//Columns //Columns
FilmID jet.ColumnInteger FilmID postgres.ColumnInteger
Title jet.ColumnString Title postgres.ColumnString
Description jet.ColumnString Description postgres.ColumnString
ReleaseYear jet.ColumnInteger ReleaseYear postgres.ColumnInteger
LanguageID jet.ColumnInteger LanguageID postgres.ColumnInteger
RentalDuration jet.ColumnInteger RentalDuration postgres.ColumnInteger
RentalRate jet.ColumnFloat RentalRate postgres.ColumnFloat
Length jet.ColumnInteger Length postgres.ColumnInteger
ReplacementCost jet.ColumnFloat ReplacementCost postgres.ColumnFloat
Rating jet.ColumnString Rating postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
SpecialFeatures jet.ColumnString SpecialFeatures postgres.ColumnString
Fulltext jet.ColumnString Fulltext postgres.ColumnString
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new FilmTable with assigned alias // creates new FilmTable with assigned alias
@ -47,23 +47,23 @@ func (a *FilmTable) AS(alias string) *FilmTable {
func newFilmTable() *FilmTable { func newFilmTable() *FilmTable {
var ( var (
FilmIDColumn = jet.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
TitleColumn = jet.StringColumn("title") TitleColumn = postgres.StringColumn("title")
DescriptionColumn = jet.StringColumn("description") DescriptionColumn = postgres.StringColumn("description")
ReleaseYearColumn = jet.IntegerColumn("release_year") ReleaseYearColumn = postgres.IntegerColumn("release_year")
LanguageIDColumn = jet.IntegerColumn("language_id") LanguageIDColumn = postgres.IntegerColumn("language_id")
RentalDurationColumn = jet.IntegerColumn("rental_duration") RentalDurationColumn = postgres.IntegerColumn("rental_duration")
RentalRateColumn = jet.FloatColumn("rental_rate") RentalRateColumn = postgres.FloatColumn("rental_rate")
LengthColumn = jet.IntegerColumn("length") LengthColumn = postgres.IntegerColumn("length")
ReplacementCostColumn = jet.FloatColumn("replacement_cost") ReplacementCostColumn = postgres.FloatColumn("replacement_cost")
RatingColumn = jet.StringColumn("rating") RatingColumn = postgres.StringColumn("rating")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
SpecialFeaturesColumn = jet.StringColumn("special_features") SpecialFeaturesColumn = postgres.StringColumn("special_features")
FulltextColumn = jet.StringColumn("fulltext") FulltextColumn = postgres.StringColumn("fulltext")
) )
return &FilmTable{ return &FilmTable{
Table: jet.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn), Table: postgres.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
//Columns //Columns
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
@ -80,7 +80,7 @@ func newFilmTable() *FilmTable {
SpecialFeatures: SpecialFeaturesColumn, SpecialFeatures: SpecialFeaturesColumn,
Fulltext: FulltextColumn, Fulltext: FulltextColumn,
AllColumns: jet.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, AllColumns: postgres.ColumnList(FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
MutableColumns: jet.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, MutableColumns: postgres.ColumnList(TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var FilmActor = newFilmActorTable() var FilmActor = newFilmActorTable()
type FilmActorTable struct { type FilmActorTable struct {
jet.Table postgres.Table
//Columns //Columns
ActorID jet.ColumnInteger ActorID postgres.ColumnInteger
FilmID jet.ColumnInteger FilmID postgres.ColumnInteger
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new FilmActorTable with assigned alias // creates new FilmActorTable with assigned alias
@ -37,20 +37,20 @@ func (a *FilmActorTable) AS(alias string) *FilmActorTable {
func newFilmActorTable() *FilmActorTable { func newFilmActorTable() *FilmActorTable {
var ( var (
ActorIDColumn = jet.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FilmIDColumn = jet.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &FilmActorTable{ return &FilmActorTable{
Table: jet.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(ActorIDColumn, FilmIDColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{LastUpdateColumn}, MutableColumns: postgres.ColumnList(LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var FilmCategory = newFilmCategoryTable() var FilmCategory = newFilmCategoryTable()
type FilmCategoryTable struct { type FilmCategoryTable struct {
jet.Table postgres.Table
//Columns //Columns
FilmID jet.ColumnInteger FilmID postgres.ColumnInteger
CategoryID jet.ColumnInteger CategoryID postgres.ColumnInteger
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new FilmCategoryTable with assigned alias // creates new FilmCategoryTable with assigned alias
@ -37,20 +37,20 @@ func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable {
func newFilmCategoryTable() *FilmCategoryTable { func newFilmCategoryTable() *FilmCategoryTable {
var ( var (
FilmIDColumn = jet.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
CategoryIDColumn = jet.IntegerColumn("category_id") CategoryIDColumn = postgres.IntegerColumn("category_id")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &FilmCategoryTable{ return &FilmCategoryTable{
Table: jet.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
//Columns //Columns
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
CategoryID: CategoryIDColumn, CategoryID: CategoryIDColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{LastUpdateColumn}, MutableColumns: postgres.ColumnList(LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Language = newLanguageTable() var Language = newLanguageTable()
type LanguageTable struct { type LanguageTable struct {
jet.Table postgres.Table
//Columns //Columns
LanguageID jet.ColumnInteger LanguageID postgres.ColumnInteger
Name jet.ColumnString Name postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new LanguageTable with assigned alias // creates new LanguageTable with assigned alias
@ -37,20 +37,20 @@ func (a *LanguageTable) AS(alias string) *LanguageTable {
func newLanguageTable() *LanguageTable { func newLanguageTable() *LanguageTable {
var ( var (
LanguageIDColumn = jet.IntegerColumn("language_id") LanguageIDColumn = postgres.IntegerColumn("language_id")
NameColumn = jet.StringColumn("name") NameColumn = postgres.StringColumn("name")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &LanguageTable{ return &LanguageTable{
Table: jet.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn),
//Columns //Columns
LanguageID: LanguageIDColumn, LanguageID: LanguageIDColumn,
Name: NameColumn, Name: NameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(LanguageIDColumn, NameColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{NameColumn, LastUpdateColumn}, MutableColumns: postgres.ColumnList(NameColumn, LastUpdateColumn),
} }
} }

View file

@ -9,24 +9,23 @@ import (
// dot import so go code would resemble as much as native SQL // dot import so go code would resemble as much as native SQL
// dot import is not mandatory // dot import is not mandatory
. "github.com/go-jet/jet"
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
. "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model" "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model"
) )
const ( const (
Host = "localhost" host = "localhost"
Port = 5432 port = 5432
User = "jet" user = "jet"
Password = "jet" password = "jet"
DBName = "jetdb" dbName = "jetdb"
) )
func main() { func main() {
// Connect to database // Connect to database
var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", host, port, user, password, dbName)
db, err := sql.Open("postgres", connectString) db, err := sql.Open("postgres", connectString)
panicOnError(err) panicOnError(err)
@ -97,17 +96,15 @@ func jsonSave(path string, v interface{}) {
} }
} }
func printStatementInfo(stmt Statement) { func printStatementInfo(stmt SelectStatement) {
query, args, err := stmt.Sql() query, args := stmt.Sql()
panicOnError(err)
fmt.Println("Parameterized query: ") fmt.Println("Parameterized query: ")
fmt.Println(query) fmt.Println(query)
fmt.Println("Arguments: ") fmt.Println("Arguments: ")
fmt.Println(args) fmt.Println(args)
debugSQL, err := stmt.DebugSql() debugSQL := stmt.DebugSql()
panicOnError(err)
fmt.Println("\n\n==============================") fmt.Println("\n\n==============================")

View file

@ -4,10 +4,10 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"github.com/go-jet/jet/execution/internal" "github.com/go-jet/jet/execution/internal"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -18,14 +18,11 @@ import (
// Destination can be either pointer to struct or pointer to slice of structs. // Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error { func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
if utils.IsNil(destinationPtr) { utils.MustBeInitializedPtr(db, "jet: db is nil")
return errors.New("jet: Destination is nil") utils.MustBeInitializedPtr(destinationPtr, "jet: destination is nil")
} utils.MustBe(destinationPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct")
destinationPtrType := reflect.TypeOf(destinationPtr) destinationPtrType := reflect.TypeOf(destinationPtr)
if destinationPtrType.Kind() != reflect.Ptr {
return errors.New("jet: Destination has to be a pointer to slice or pointer to struct")
}
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(context, db, query, args, destinationPtr) return queryToSlice(context, db, query, args, destinationPtr)
@ -51,24 +48,11 @@ func Query(context context.Context, db DB, query string, args []interface{}, des
} }
return nil return nil
} else { } else {
return errors.New("jet: unsupported destination type") panic("jet: destination has to be a pointer to slice or pointer to struct")
} }
} }
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error { func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error {
if db == nil {
return errors.New("jet: db is nil")
}
if slicePtr == nil {
return errors.New("jet: Destination is nil. ")
}
destinationType := reflect.TypeOf(slicePtr)
if destinationType.Kind() != reflect.Ptr && destinationType.Elem().Kind() != reflect.Slice {
return errors.New("jet: Destination has to be a pointer to slice. ")
}
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
@ -126,14 +110,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
sliceElemType := getSliceElemType(slicePtrValue) sliceElemType := getSliceElemType(slicePtrValue)
if isGoBaseType(sliceElemType) { if isSimpleModelType(sliceElemType) {
updated, err = mapRowToBaseTypeSlice(scanContext, slicePtrValue, field) updated, err = mapRowToBaseTypeSlice(scanContext, slicePtrValue, field)
return return
} }
if sliceElemType.Kind() != reflect.Struct { utils.TypeMustBe(sliceElemType, reflect.Struct, "jet: unsupported slice element type"+fieldToString(field))
return false, errors.New("jet: Unsupported dest type: " + field.Name + " " + field.Type.String())
}
structGroupKey := scanContext.getGroupKey(sliceElemType, field) structGroupKey := scanContext.getGroupKey(sliceElemType, field)
@ -226,7 +208,7 @@ func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
if implementsScannerType(field.Type) { if implementsScannerType(field.Type) {
fieldMap.implementsScanner = true fieldMap.implementsScanner = true
} else if !isGoBaseType(field.Type) { } else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true fieldMap.complexType = true
} }
@ -249,6 +231,10 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
field := structType.Field(i) field := structType.Field(i)
fieldValue := structValue.Field(i) fieldValue := structValue.Field(i)
if !fieldValue.CanSet() { // private field
continue
}
fieldMap := typeInf.fieldMappings[i] fieldMap := typeInf.fieldMappings[i]
if fieldMap.complexType { if fieldMap.complexType {
@ -284,8 +270,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
err = scanner.Scan(cellValue) err = scanner.Scan(cellValue)
if err != nil { if err != nil {
err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String())
return
} }
updated = true updated = true
} else { } else {
@ -294,12 +279,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
if cellValue != nil { if cellValue != nil {
updated = true updated = true
initializeValueIfNilPtr(fieldValue) initializeValueIfNilPtr(fieldValue)
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if err != nil {
err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String())
return
}
} }
} }
} }
@ -310,9 +290,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
if destPtrValue.Kind() != reflect.Ptr { utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
return false, errors.New("jet: Internal error. ")
}
destValueKind := destPtrValue.Elem().Kind() destValueKind := destPtrValue.Elem().Kind()
@ -321,7 +299,7 @@ func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrVa
} else if destValueKind == reflect.Slice { } else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else { } else {
return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String()) panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
} }
} }
@ -331,14 +309,12 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
if dest.Kind() != reflect.Ptr { if dest.Kind() != reflect.Ptr {
destPtrValue = dest.Addr() destPtrValue = dest.Addr()
} else if dest.Kind() == reflect.Ptr { } else {
if dest.IsNil() { if dest.IsNil() {
destPtrValue = reflect.New(dest.Type().Elem()) destPtrValue = reflect.New(dest.Type().Elem())
} else { } else {
destPtrValue = dest destPtrValue = dest
} }
} else {
return false, errors.New("jet: Internal error. ")
} }
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
@ -399,7 +375,7 @@ func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error { func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error {
if slicePtrValue.IsNil() { if slicePtrValue.IsNil() {
panic("Slice is nil") panic("jet: internal, slice is nil")
} }
sliceValue := slicePtrValue.Elem() sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem() sliceElemType := sliceValue.Type().Elem()
@ -410,8 +386,12 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
newElemValue = objPtrValue.Elem() newElemValue = objPtrValue.Elem()
} }
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
}
if !newElemValue.Type().AssignableTo(sliceElemType) { if !newElemValue.Type().AssignableTo(sliceElemType) {
return fmt.Errorf("jet: can't append %s to %s slice ", newElemValue.Type().String(), sliceValue.Type().String()) panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
} }
sliceValue.Set(reflect.Append(sliceValue, newElemValue)) sliceValue.Set(reflect.Append(sliceValue, newElemValue))
@ -465,6 +445,7 @@ func toCommonIdentifier(name string) string {
} }
func initializeValueIfNilPtr(value reflect.Value) { func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() { if !value.IsValid() || !value.CanSet() {
return return
} }
@ -490,55 +471,119 @@ func valueToString(value reflect.Value) string {
valueInterface = value.Interface() valueInterface = value.Interface()
} }
if t, ok := valueInterface.(time.Time); ok { if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String() return t.String()
} }
return fmt.Sprintf("%#v", valueInterface) return fmt.Sprintf("%#v", valueInterface)
} }
func isGoBaseType(objType reflect.Type) bool { var timeType = reflect.TypeOf(time.Now())
typeStr := objType.String() var uuidType = reflect.TypeOf(uuid.New())
switch typeStr { func isSimpleModelType(objType reflect.Type) bool {
case "string", "int", "int16", "int32", "int64", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8", objType = indirectType(objType)
"*string", "*int", "*int16", "*int32", "*int64", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String,
reflect.Bool:
return true
case reflect.Slice:
return objType.Elem().Kind() == reflect.Uint8 //[]byte
case reflect.Struct:
return objType == timeType || objType == uuidType // time.Time || uuid.UUID
}
return false
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
return true return true
} }
return false return false
} }
func setReflectValue(source, destination reflect.Value) error { func tryAssign(source, destination reflect.Value) bool {
var sourceElem reflect.Value if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if isIntegerType(source.Type()) && destination.Type() == boolType {
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
}
return false
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
if destination.Kind() == reflect.Ptr { if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
sourceElem = source if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else { } else {
if source.CanAddr() { if source.CanAddr() {
sourceElem = source.Addr() source = source.Addr()
} else { } else {
sourceCopy := reflect.New(source.Type()) sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source) sourceCopy.Elem().Set(source)
sourceElem = sourceCopy source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
} }
} }
} else { } else {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
sourceElem = source.Elem() if source.IsNil() {
} else { return
sourceElem = source }
source = source.Elem()
}
if tryAssign(source, destination) {
return
} }
} }
if !sourceElem.Type().AssignableTo(destination.Type()) { panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
return errors.New("jet: can't set " + sourceElem.Type().String() + " to " + destination.Type().String())
}
destination.Set(sourceElem)
return nil
} }
func createScanValue(columnTypes []*sql.ColumnType) []interface{} { func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
@ -555,35 +600,49 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
return values return values
} }
var nullFloatType = reflect.TypeOf(internal.NullFloat32{}) var boolType = reflect.TypeOf(true)
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{}) var nullStringType = reflect.TypeOf(sql.NullString{})
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullTimeType = reflect.TypeOf(internal.NullTime{}) var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type { func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() { switch columnType.DatabaseTypeName() {
case "INT2": case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type return nullInt16Type
case "INT4": case "INT4", "MEDIUMINT", "INT":
return nullInt32Type return nullInt32Type
case "INT8": case "INT8", "BIGINT":
return nullInt64Type return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML": case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType return nullStringType
case "FLOAT4": case "FLOAT4":
return nullFloatType return nullFloat32Type
case "FLOAT8", "NUMERIC", "DECIMAL": case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE":
return nullFloat64Type return nullFloat64Type
case "BOOL": case "BOOL":
return nullBoolType return nullBoolType
case "BYTEA": case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType return nullByteArrayType
case "DATE", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ": case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType return nullTimeType
default: default:
return nullStringType return nullStringType
@ -697,7 +756,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
field := structType.Field(i) field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field) newTypeName, fieldName := getTypeAndFieldName(typeName, field)
if !isGoBaseType(field.Type) { if !isSimpleModelType(field.Type) {
var structType reflect.Type var structType reflect.Type
if field.Type.Kind() == reflect.Struct { if field.Type.Kind() == reflect.Struct {
structType = field.Type structType = field.Type
@ -749,7 +808,7 @@ func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer) valuer, ok := s.row[index].(driver.Valuer)
if !ok { if !ok {
panic("Scan value doesn't implement driver.Valuer") panic("jet: internal error, scan value doesn't implement driver.Valuer")
} }
value, err := valuer.Value() value, err := valuer.Value()
@ -791,3 +850,11 @@ func indirectType(reflectType reflect.Type) reflect.Type {
} }
return reflectType.Elem() return reflectType.Elem()
} }
func fieldToString(field *reflect.StructField) string {
if field == nil {
return ""
}
return " at '" + field.Name + " " + field.Type.String() + "'"
}

View file

@ -2,9 +2,13 @@ package internal
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt"
"strconv"
"time" "time"
) )
//===============================================================//
// NullByteArray struct // NullByteArray struct
type NullByteArray struct { type NullByteArray struct {
ByteArray []byte ByteArray []byte
@ -14,13 +18,16 @@ type NullByteArray struct {
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error { func (nb *NullByteArray) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case nil:
nb.Valid = false
return nil
case []byte: case []byte:
nb.ByteArray = append(v[:0:0], v...) nb.ByteArray = append(v[:0:0], v...)
nb.Valid = true nb.Valid = true
return nil
default: default:
nb.Valid = false return fmt.Errorf("can't scan []byte from %v", value)
} }
return nil
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
@ -31,6 +38,8 @@ func (nb NullByteArray) Value() (driver.Value, error) {
return nb.ByteArray, nil return nb.ByteArray, nil
} }
//===============================================================//
// NullTime struct // NullTime struct
type NullTime struct { type NullTime struct {
Time time.Time Time time.Time
@ -38,9 +47,23 @@ type NullTime struct {
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error { func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = value.(time.Time) switch v := value.(type) {
return nil case nil:
nt.Valid = false
return
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, nt.Valid = parseTime(string(v))
return
case string:
nt.Time, nt.Valid = parseTime(v)
return
default:
return fmt.Errorf("can't scan time from %v", value)
}
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
@ -51,62 +74,100 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil return nt.Time, nil
} }
// NullInt32 struct const formatTime = "2006-01-02 15:04:05.999999"
type NullInt32 struct {
Int32 int32 func parseTime(timeStr string) (t time.Time, valid bool) {
Valid bool // Valid is true if Int64 is not NULL
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
}
t, err := time.Parse(format, timeStr)
return t, err == nil
}
//===============================================================//
// NullInt8 struct
type NullInt8 struct {
Int8 int8
Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error { func (n *NullInt8) Scan(value interface{}) (err error) {
switch v := value.(type) { switch v := value.(type) {
case nil:
n.Valid = false
return
case int64: case int64:
n.Int32, n.Valid = int32(v), true n.Int8, n.Valid = int8(v), true
return nil return
case int32: case int8:
n.Int32, n.Valid = v, true n.Int8, n.Valid = v, true
return nil return
case uint8: case []byte:
n.Int32, n.Valid = int32(v), true intV, err := strconv.ParseInt(string(v), 10, 8)
return nil if err == nil {
n.Int8, n.Valid = int8(intV), true
}
return err
default:
return fmt.Errorf("can't scan int8 from %v", value)
} }
n.Valid = false
return nil
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) { func (n NullInt8) Value() (driver.Value, error) {
if !n.Valid { if !n.Valid {
return nil, nil return nil, nil
} }
return n.Int32, nil return n.Int8, nil
} }
//===============================================================//
// NullInt16 struct // NullInt16 struct
type NullInt16 struct { type NullInt16 struct {
Int16 int16 Int16 int16
Valid bool // Valid is true if Int64 is not NULL Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error { func (n *NullInt16) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64: case int64:
n.Int16, n.Valid = int16(v), true n.Int16, n.Valid = int16(v), true
return nil return nil
case int16: case int16:
n.Int16, n.Valid = v, true n.Int16, n.Valid = v, true
return nil return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8: case uint8:
n.Int16, n.Valid = int16(v), true n.Int16, n.Valid = int16(v), true
return nil return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
}
return nil
default:
return fmt.Errorf("can't scan int16 from %v", value)
} }
n.Valid = false
return nil
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
@ -117,26 +178,80 @@ func (n NullInt16) Value() (driver.Value, error) {
return n.Int16, nil return n.Int16, nil
} }
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int32, n.Valid = int32(v), true
return nil
case int32:
n.Int32, n.Valid = v, true
return nil
case int16:
n.Int32, n.Valid = int32(v), true
return nil
case uint16:
n.Int32, n.Valid = int32(v), true
return nil
case int8:
n.Int32, n.Valid = int32(v), true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32)
if err == nil {
n.Int32, n.Valid = int32(intV), true
}
return nil
default:
return fmt.Errorf("can't scan int32 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
//===============================================================//
// NullFloat32 struct // NullFloat32 struct
type NullFloat32 struct { type NullFloat32 struct {
Float32 float32 Float32 float32
Valid bool // Valid is true if Int64 is not NULL Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullFloat32) Scan(value interface{}) error { func (n *NullFloat32) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case nil:
n.Valid = false
return nil
case float64: case float64:
n.Float32, n.Valid = float32(v), true n.Float32, n.Valid = float32(v), true
return nil return nil
case float32: case float32:
n.Float32, n.Valid = v, true n.Float32, n.Valid = v, true
return nil return nil
default:
return fmt.Errorf("can't scan float32 from %v", value)
} }
n.Valid = false
return nil
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.

View file

@ -1,194 +0,0 @@
package jet
import (
"errors"
)
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
clause
projection
groupByClause
orderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression
AS(alias string) projection
// Expression will be used to sort query result in ascending order
ASC() orderByClause
// Expression will be used to sort query result in ascending order
DESC() orderByClause
}
type expressionInterfaceImpl struct {
parent Expression
}
func (e *expressionInterfaceImpl) from(subQuery SelectTable) projection {
return e.parent
}
func (e *expressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NULL")
}
func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NOT NULL")
}
func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.parent, WRAP(expressions...), "IN")
}
func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.parent, WRAP(expressions...), "NOT IN")
}
func (e *expressionInterfaceImpl) AS(alias string) projection {
return newAlias(e.parent, alias)
}
func (e *expressionInterfaceImpl) ASC() orderByClause {
return newOrderByClause(e.parent, true)
}
func (e *expressionInterfaceImpl) DESC() orderByClause {
return newOrderByClause(e.parent, false)
}
func (e *expressionInterfaceImpl) serializeForGroupBy(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForProjection(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct {
lhs, rhs Expression
operator string
}
func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpression {
binaryExpression := binaryOpExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
return binaryExpression
}
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c == nil {
return errors.New("jet: binary Expression is nil")
}
if c.lhs == nil {
return errors.New("jet: nil lhs")
}
if c.rhs == nil {
return errors.New("jet: nil rhs")
}
wrap := !contains(options, noWrap)
if wrap {
out.writeString("(")
}
if err := c.lhs.serialize(statement, out); err != nil {
return err
}
out.writeString(c.operator)
if err := c.rhs.serialize(statement, out); err != nil {
return err
}
if wrap {
out.writeString(")")
}
return nil
}
// A prefix operator Expression
type prefixOpExpression struct {
expression Expression
operator string
}
func newPrefixExpression(expression Expression, operator string) prefixOpExpression {
prefixExpression := prefixOpExpression{
expression: expression,
operator: operator,
}
return prefixExpression
}
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Prefix Expression is nil")
}
out.writeString(p.operator + " ")
if p.expression == nil {
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
}
return nil
}
// A postifx operator Expression
type postfixOpExpression struct {
expression Expression
operator string
}
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression {
postfixOpExpression := postfixOpExpression{
expression: expression,
operator: operator,
}
return postfixOpExpression
}
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Postifx operator Expression is nil")
}
if p.expression == nil {
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
}
out.writeString(p.operator)
return nil
}

View file

@ -0,0 +1,175 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// ColumnMetaData struct
type ColumnMetaData struct {
Name string
IsNullable bool
DataType string
EnumName string
IsUnsigned bool
SqlBuilderColumnType string
GoBaseType string
GoModelType string
}
// NewColumnMetaData create new column meta data that describes one column in SQL database
func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData {
columnMetaData := ColumnMetaData{
Name: name,
IsNullable: isNullable,
DataType: dataType,
EnumName: enumName,
IsUnsigned: isUnsigned,
}
columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType()
columnMetaData.GoBaseType = columnMetaData.getGoBaseType()
columnMetaData.GoModelType = columnMetaData.getGoModelType()
return columnMetaData
}
// getSqlBuilderColumnType returns type of jet sql builder column
func (c ColumnMetaData) getSqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.")
return "String"
}
}
// getGoBaseType returns model type for column info.
func (c ColumnMetaData) getGoBaseType() string {
switch c.DataType {
case "USER-DEFINED", "enum":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "tinyint":
return "int8"
case "smallint",
"year":
return "int16"
case "integer",
"mediumint", "int": //MySQL
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone",
"timestamp", "datetime", "time": // MySQL
return "time.Time"
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnMetaData) getGoModelType() string {
typeStr := c.GoBaseType
if strings.Contains(typeStr, "int") && c.IsUnsigned {
typeStr = "u" + typeStr
}
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnMetaData{}
for rows.Next() {
var name, isNullable, dataType, enumName string
var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
if err != nil {
return nil, err
}
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,15 @@
package metadata
import (
"database/sql"
)
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information
type DialectQuerySet interface {
ListOfTablesQuery() string
PrimaryKeysQuery() string
ListOfColumnsQuery() string
ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error)
}

View file

@ -0,0 +1,12 @@
package metadata
// EnumMetaData struct
type EnumMetaData struct {
EnumName string
Values []string
}
// Name returns enum name
func (e EnumMetaData) Name() string {
return e.EnumName
}

View file

@ -1,142 +0,0 @@
package postgresmeta
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// ColumnInfo metadata struct
type ColumnInfo struct {
Name string
IsNullable bool
DataType string
EnumName string
}
// SqlBuilderColumnType returns type of jet sql builder column
func (c ColumnInfo) SqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint":
return "Integer"
case "date":
return "Date"
case "timestamp without time zone":
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone":
return "Time"
case "time with time zone":
return "Timez"
case "USER-DEFINED", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY":
return "String"
case "real", "numeric", "decimal", "double precision":
return "Float"
default:
fmt.Println("Unsupported sql type: " + c.DataType + ", using string column instead for sql builder.")
return "String"
}
}
// GoBaseType returns model type for column info.
func (c ColumnInfo) GoBaseType() string {
switch c.DataType {
case "USER-DEFINED":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "smallint":
return "int16"
case "integer":
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone":
return "time.Time"
case "bytea":
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY":
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision":
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("Unsupported sql type: " + c.DataType + ", " + c.EnumName + ", using string instead for model type.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type, udt_name
FROM information_schema.columns
where table_catalog = $1 and table_schema = $2 and table_name = $3
order by ordinal_position;`
rows, err := db.Query(query, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnInfo{}
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
columnInfo.IsNullable = isNullable == "YES"
if err != nil {
return nil, err
}
ret = append(ret, columnInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,68 +0,0 @@
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// EnumInfo struct
type EnumInfo struct {
name string
Values []string
}
// Name returns enum name
func (e EnumInfo) Name() string {
return e.name
}
func getEnumInfos(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
query := `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
rows, err := db.Query(query, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
enumsInfosMap := map[string][]string{}
for rows.Next() {
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
if err != nil {
return nil, err
}
enumValues := enumsInfosMap[enumName]
enumValues = append(enumValues, enumValue)
enumsInfosMap[enumName] = enumValues
}
err = rows.Err()
if err != nil {
return nil, err
}
ret := []metadata.MetaData{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, EnumInfo{
enumName,
enumValues,
})
}
return ret, nil
}

View file

@ -1,76 +0,0 @@
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// SchemaInfo metadata struct
type SchemaInfo struct {
DatabaseName string
Name string
TableInfos []metadata.MetaData
EnumInfos []metadata.MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
schemaInfo.DatabaseName = databaseName
schemaInfo.Name = schemaName
schemaInfo.TableInfos, err = getTableInfos(db, databaseName, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = getEnumInfos(db, schemaName)
if err != nil {
return
}
return
}
func getTableInfos(db *sql.DB, dbName, schemaName string) ([]metadata.MetaData, error) {
query := `
SELECT table_name
FROM information_schema.tables
where table_catalog = $1 and table_schema = $2 and table_type = 'BASE TABLE';
`
rows, err := db.Query(query, dbName, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []metadata.MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,68 @@
package metadata
import (
"database/sql"
"fmt"
)
// SchemaMetaData struct
type SchemaMetaData struct {
TableInfos []MetaData
EnumInfos []MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) {
schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName)
if err != nil {
return
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
return
}
func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,31 +1,31 @@
package postgresmeta package metadata
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
) )
// TableInfo metadata struct // TableMetaData metadata struct
type TableInfo struct { type TableMetaData struct {
SchemaName string SchemaName string
name string name string
PrimaryKeys map[string]bool PrimaryKeys map[string]bool
Columns []ColumnInfo Columns []ColumnMetaData
} }
// Name returns table info name // Name returns table info name
func (t TableInfo) Name() string { func (t TableMetaData) Name() string {
return t.name return t.name
} }
// IsPrimaryKey returns if column is a part of primary key // IsPrimaryKey returns if column is a part of primary key
func (t TableInfo) IsPrimaryKey(column string) bool { func (t TableMetaData) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column] return t.PrimaryKeys[column]
} }
// MutableColumns returns list of mutable columns for table // MutableColumns returns list of mutable columns for table
func (t TableInfo) MutableColumns() []ColumnInfo { func (t TableMetaData) MutableColumns() []ColumnMetaData {
ret := []ColumnInfo{} ret := []ColumnMetaData{}
for _, column := range t.Columns { for _, column := range t.Columns {
if t.IsPrimaryKey(column.Name) { if t.IsPrimaryKey(column.Name) {
@ -39,11 +39,11 @@ func (t TableInfo) MutableColumns() []ColumnInfo {
} }
// GetImports returns model imports for table. // GetImports returns model imports for table.
func (t TableInfo) GetImports() []string { func (t TableMetaData) GetImports() []string {
imports := map[string]string{} imports := map[string]string{}
for _, column := range t.Columns { for _, column := range t.Columns {
columnType := column.GoBaseType() columnType := column.GoBaseType
switch columnType { switch columnType {
case "time.Time": case "time.Time":
@ -63,22 +63,22 @@ func (t TableInfo) GetImports() []string {
} }
// GoStructName returns go struct name for sql builder // GoStructName returns go struct name for sql builder
func (t TableInfo) GoStructName() string { func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table" return utils.ToGoIdentifier(t.name) + "Table"
} }
// GetTableInfo returns table info metadata // GetTableInfo returns table info metadata
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) { func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) {
tableInfo.SchemaName = schemaName tableInfo.SchemaName = schemaName
tableInfo.name = tableName tableInfo.name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, dbName, schemaName, tableName) tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName)
if err != nil { if err != nil {
return return
} }
tableInfo.Columns, err = getColumnInfos(db, dbName, schemaName, tableName) tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName)
if err != nil { if err != nil {
return return
@ -87,15 +87,9 @@ func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo T
return return
} }
func getPrimaryKeys(db *sql.DB, dbName, schemaName, tableName string) (map[string]bool, error) { func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) {
query := `
SELECT c.column_name rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_catalog = $1 AND t.table_schema = $2 AND t.table_name = $3 AND t.constraint_type = 'PRIMARY KEY';
`
rows, err := db.Query(query, dbName, schemaName, tableName)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -0,0 +1,119 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"path/filepath"
"text/template"
"time"
)
// GenerateFiles generates Go files from tables and enums metadata
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
if len(tables) == 0 && len(enums) == 0 {
return nil
}
fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir)
if err != nil {
return err
}
fmt.Println("Generating table sql builder files...")
err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(destDir, "model", tableModelTemplate, tables, dialect)
if err != nil {
return err
}
if len(enums) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(destDir, "model", enumModelTemplate, enums, dialect)
if err != nil {
return err
}
}
fmt.Println("Done")
return nil
}
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
"dialect": func() jet.Dialect {
return dialect1
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -1,4 +1,4 @@
package postgres package template
var autoGenWarningTemplate = ` var autoGenWarningTemplate = `
// //
@ -11,7 +11,7 @@ var autoGenWarningTemplate = `
` `
var sqlBuilderTableTemplate = ` var tableSQLBuilderTemplate = `
{{define "column-list" -}} {{define "column-list" -}}
{{- range $i, $c := . }} {{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
@ -21,21 +21,21 @@ var sqlBuilderTableTemplate = `
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/{{dialect.PackageName}}"
) )
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
type {{.GoStructName}} struct { type {{.GoStructName}} struct {
jet.Table {{dialect.PackageName}}.Table
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}} jet.Column{{.SqlBuilderColumnType}} {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}} {{- end}}
AllColumns jet.ColumnList AllColumns {{dialect.PackageName}}.IColumnList
MutableColumns jet.ColumnList MutableColumns {{dialect.PackageName}}.IColumnList
} }
// creates new {{.GoStructName}} with assigned alias // creates new {{.GoStructName}} with assigned alias
@ -50,26 +50,26 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
func new{{.GoStructName}}() *{{.GoStructName}} { func new{{.GoStructName}}() *{{.GoStructName}} {
var ( var (
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}Column = jet.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}} {{- end}}
) )
return &{{.GoStructName}}{ return &{{.GoStructName}}{
Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}} {{- end}}
AllColumns: jet.ColumnList{ {{template "column-list" .Columns}} }, AllColumns: {{dialect.PackageName}}.ColumnList( {{template "column-list" .Columns}} ),
MutableColumns: jet.ColumnList{ {{template "column-list" .MutableColumns}} }, MutableColumns: {{dialect.PackageName}}.ColumnList( {{template "column-list" .MutableColumns}} ),
} }
} }
` `
var dataModelTemplate = `package model var tableModelTemplate = `package model
{{ if .GetImports }} {{ if .GetImports }}
import ( import (
@ -85,6 +85,22 @@ type {{ToGoIdentifier .Name}} struct {
{{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` {{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + `
{{- end}} {{- end}}
} }
`
var enumSQLBuilderTemplate = `package enum
import "github.com/go-jet/jet/{{dialect.PackageName}}"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"),
{{- end}}
}
` `
var enumModelTemplate = `package model var enumModelTemplate = `package model
@ -121,17 +137,3 @@ func (e {{ToGoIdentifier $.Name}}) String() string {
} }
` `
var enumTypeTemplate = `package enum
import "github.com/go-jet/jet"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} jet.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: jet.NewEnumValue("{{$element}}"),
{{- end}}
}
`

View file

@ -0,0 +1,71 @@
package mysql
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/mysql"
"path"
)
// DBConnection contains MySQL connection details
type DBConnection struct {
Host string
Port int
User string
Password string
Params string
DBName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
if err != nil {
return err
}
defer utils.DBClose(db)
fmt.Println("Retrieving database information...")
// No schemas in MySQL
dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &mySqlQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName)
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect)
if err != nil {
return err
}
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" {
connectionString += "?" + dbConn.Params
}
db, err := sql.Open("mysql", connectionString)
fmt.Println("Connecting to MySQL database: " + connectionString)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -0,0 +1,88 @@
package mysql
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
"strings"
)
// mySqlQuerySet is dialect query set for MySQL
type mySqlQuerySet struct{}
func (m *mySqlQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = 'BASE TABLE';
`
}
func (m *mySqlQuerySet) PrimaryKeysQuery() string {
return `
SELECT k.column_name
FROM information_schema.table_constraints t
JOIN information_schema.key_column_usage k
USING(constraint_name,table_schema,table_name)
WHERE t.constraint_type='PRIMARY KEY'
AND t.table_schema= ?
AND t.table_name= ?;
`
}
func (m *mySqlQuerySet) ListOfColumnsQuery() string {
return `
SELECT COLUMN_NAME,
IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE),
IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''),
COLUMN_TYPE LIKE '%unsigned%'
FROM information_schema.columns
WHERE table_schema = ? and table_name = ?
ORDER BY ordinal_position;
`
}
func (m *mySqlQuerySet) ListOfEnumsQuery() string {
return `
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5)
FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE';
`
}
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []metadata.MetaData{}
for rows.Next() {
var enumName string
var enumValues string
err = rows.Scan(&enumName, &enumValues)
if err != nil {
return nil, err
}
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
ret = append(ret, metadata.EnumMetaData{
EnumName: enumName,
Values: strings.Split(enumValues, ","),
})
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,134 +0,0 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/metadata/postgresmeta"
"github.com/go-jet/jet/internal/utils"
"path"
"path/filepath"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
SchemaName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return err
}
defer db.Close()
err = db.Ping()
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := postgresmeta.GetSchemaInfo(db, dbConn.DBName, dbConn.SchemaName)
if err != nil {
return err
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
if len(schemaInfo.TableInfos) == 0 && len(schemaInfo.EnumInfos) == 0 {
return nil
}
schemaGenPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
fmt.Println("Destination directory:", schemaGenPath)
fmt.Println("Cleaning up destination directory...")
err = utils.CleanUpGeneratedFiles(schemaGenPath)
if err != nil {
return err
}
fmt.Println("Generating table sql builder files...")
err = generate(schemaInfo, destDir, "table", sqlBuilderTableTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(schemaInfo, destDir, "model", dataModelTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
if len(schemaInfo.EnumInfos) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(schemaInfo, destDir, "enum", enumTypeTemplate, schemaInfo.EnumInfos)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos)
if err != nil {
return err
}
}
fmt.Println("Done")
return nil
}
func generate(schemaInfo postgresmeta.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := utils.GenerateTemplate(autoGenWarningTemplate, nil)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := utils.GenerateTemplate(template, metaData)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,73 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/postgres"
"path"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
SchemaName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
defer utils.DBClose(db)
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &postgresQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect)
if err != nil {
return err
}
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -0,0 +1,88 @@
package postgres
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// postgresQuerySet is dialect query set for PostgreSQL
type postgresQuerySet struct{}
func (p *postgresQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';
`
}
func (p *postgresQuerySet) PrimaryKeysQuery() string {
return `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY';
`
}
func (p *postgresQuerySet) ListOfColumnsQuery() string {
return `
SELECT column_name, is_nullable, data_type, udt_name, FALSE
FROM information_schema.columns
where table_schema = $1 and table_name = $2
order by ordinal_position;`
}
func (p *postgresQuerySet) ListOfEnumsQuery() string {
return `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
}
func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
rows, err := db.Query(p.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
enumsInfosMap := map[string][]string{}
for rows.Next() {
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
if err != nil {
return nil, err
}
enumValues := enumsInfosMap[enumName]
enumValues = append(enumValues, enumValue)
enumsInfosMap[enumName] = enumValues
}
err = rows.Err()
if err != nil {
return nil, err
}
ret := []metadata.MetaData{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, metadata.EnumMetaData{
EnumName: enumName,
Values: enumValues,
})
}
return ret, nil
}

View file

@ -1,5 +0,0 @@
package jet
type groupByClause interface {
serializeForGroupBy(statement statementType, out *sqlBuilder) error
}

View file

@ -1,170 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
// Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...projection) InsertStatement
}
func newInsertStatement(t WritableTable, columns []column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
}
}
type insertStatementImpl struct {
table WritableTable
columns []column
rows [][]clause
query SelectStatement
returning []projection
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromValues(value, values))
return i
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromModel(i.getColumns(), data))
return i
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowsFromModels(i.getColumns(), data)...)
return i
}
func (i *insertStatementImpl) RETURNING(projections ...projection) InsertStatement {
i.returning = projections
return i
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement
return i
}
func (i *insertStatementImpl) getColumns() []column {
if len(i.columns) > 0 {
return i.columns
}
return i.table.columns()
}
func (i *insertStatementImpl) DebugSql() (query string, err error) {
return debugSql(i)
}
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
queryData := &sqlBuilder{}
queryData.newLine()
queryData.writeString("INSERT INTO")
if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil")
}
err = i.table.serialize(insertStatement, queryData)
if err != nil {
return
}
if len(i.columns) > 0 {
queryData.writeString("(")
err = serializeColumnNames(i.columns, queryData)
if err != nil {
return
}
queryData.writeString(")")
}
if len(i.rows) == 0 && i.query == nil {
return "", nil, errors.New("jet: no row values or query specified")
}
if len(i.rows) > 0 && i.query != nil {
return "", nil, errors.New("jet: only row values or query has to be specified")
}
if len(i.rows) > 0 {
queryData.writeString("VALUES")
for rowIndex, row := range i.rows {
if rowIndex > 0 {
queryData.writeString(",")
}
queryData.increaseIdent()
queryData.newLine()
queryData.writeString("(")
err = serializeClauseList(insertStatement, row, queryData)
if err != nil {
return "", nil, err
}
queryData.writeByte(')')
queryData.decreaseIdent()
}
}
if i.query != nil {
err = i.query.serialize(insertStatement, queryData)
if err != nil {
return
}
}
if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
return
}
sql, args = queryData.finalize()
return
}
func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db)
}
func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, i, db)
}

View file

@ -1,13 +0,0 @@
package snaker
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestDb(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Snaker Suite")
}

View file

@ -1,40 +1,16 @@
package snaker package snaker
import ( import (
. "github.com/onsi/ginkgo" "gotest.tools/assert"
. "github.com/onsi/gomega" "testing"
) )
var _ = Describe("Snaker", func() { func TestSnakeToCamel(t *testing.T) {
assert.Equal(t, SnakeToCamel(""), "")
Describe("SnakeToCamel test", func() { assert.Equal(t, SnakeToCamel("potato_"), "Potato")
It("should return an empty string on an empty input", func() { assert.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
Expect(SnakeToCamel("")).To(Equal("")) assert.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
}) assert.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")
assert.Equal(t, SnakeToCamel("id"), "ID")
It("should not blow up on trailing _", func() { assert.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient")
Expect(SnakeToCamel("potato_")).To(Equal("Potato")) }
})
It("should return a snaked text as camel case", func() {
Expect(SnakeToCamel("this_has_to_be_uppercased")).To(
Equal("ThisHasToBeUppercased"))
})
It("should return a snaked text as camel case, except the word ID", func() {
Expect(SnakeToCamel("this_is_an_id")).To(Equal("ThisIsAnID"))
})
It("should return 'id' not as uppercase", func() {
Expect(SnakeToCamel("this_is_an_identifier")).To(Equal("ThisIsAnIdentifier"))
})
It("should simply work with id", func() {
Expect(SnakeToCamel("id")).To(Equal("ID"))
})
It("should work with initialism where only certain characters are uppercase", func() {
Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient"))
})
})
})

28
internal/jet/alias.go Normal file
View file

@ -0,0 +1,28 @@
package jet
type alias struct {
expression Expression
alias string
}
func newAlias(expression Expression, aliasName string) Projection {
return &alias{
expression: expression,
alias: aliasName,
}
}
func (a *alias) fromImpl(subQuery SelectTable) Projection {
column := newColumn(a.alias, "", nil)
column.Parent = &column
column.subQuery = subQuery
return &column
}
func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder) {
a.expression.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(a.alias)
}

View file

@ -92,14 +92,14 @@ type binaryBoolExpression struct {
binaryOpExpression binaryOpExpression
} }
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression { func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression {
boolExpression := binaryBoolExpression{} binaryBoolExpression := binaryBoolExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator, additionalParams...)
boolExpression.expressionInterfaceImpl.parent = &boolExpression binaryBoolExpression.expressionInterfaceImpl.Parent = &binaryBoolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
return &boolExpression return &binaryBoolExpression
} }
//---------------------------------------------------// //---------------------------------------------------//
@ -114,7 +114,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio
exp := prefixBoolExpression{} exp := prefixBoolExpression{}
exp.prefixOpExpression = newPrefixExpression(expression, operator) exp.prefixOpExpression = newPrefixExpression(expression, operator)
exp.expressionInterfaceImpl.parent = &exp exp.expressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp exp.boolInterfaceImpl.parent = &exp
return &exp return &exp
@ -132,7 +132,7 @@ func newPostifxBoolExpression(expression Expression, operator string) BoolExpres
exp := postfixBoolOpExpression{} exp := postfixBoolOpExpression{}
exp.postfixOpExpression = newPostfixOpExpression(expression, operator) exp.postfixOpExpression = newPostfixOpExpression(expression, operator)
exp.expressionInterfaceImpl.parent = &exp exp.expressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp exp.boolInterfaceImpl.parent = &exp
return &exp return &exp

View file

@ -5,9 +5,8 @@ import (
) )
func TestBoolExpressionEQ(t *testing.T) { func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: nil rhs")
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)") assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.col_bool = $1)", true) assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
} }
func TestBoolExpressionNOT_EQ(t *testing.T) { func TestBoolExpressionNOT_EQ(t *testing.T) {
@ -57,6 +56,7 @@ func TestBinaryBoolExpression(t *testing.T) {
boolExpression := Int(2).EQ(Int(3)) boolExpression := Int(2).EQ(Int(3))
assertClauseSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3)) assertClauseSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3))
assertProjectionSerialize(t, boolExpression, "$1 = $2", int64(2), int64(3)) assertProjectionSerialize(t, boolExpression, "$1 = $2", int64(2), int64(3))
assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"), assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"),
`($1 = $2) AS "alias_eq_expression"`, int64(2), int64(3)) `($1 = $2) AS "alias_eq_expression"`, int64(2), int64(3))
@ -71,20 +71,6 @@ func TestBoolLiteral(t *testing.T) {
assertClauseSerialize(t, Bool(false), "$1", false) assertClauseSerialize(t, Bool(false), "$1", false)
} }
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`EXISTS (
SELECT $1
FROM db.table2
WHERE table1.col1 = table2.col3
)`, int64(1))
}
func TestBoolExp(t *testing.T) { func TestBoolExp(t *testing.T) {
assertClauseSerialize(t, BoolExp(String("true")), "$1", "true") assertClauseSerialize(t, BoolExp(String("true")), "$1", "true")
assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true") assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true")

53
internal/jet/cast.go Normal file
View file

@ -0,0 +1,53 @@
package jet
// Cast interface
type Cast interface {
AS(castType string) Expression
}
type castImpl struct {
expression Expression
}
// NewCastImpl creates new generic cast
func NewCastImpl(expression Expression) Cast {
castImpl := castImpl{
expression: expression,
}
return &castImpl
}
func (b *castImpl) AS(castType string) Expression {
castExp := &castExpression{
expression: b.expression,
cast: string(castType),
}
castExp.expressionInterfaceImpl.Parent = castExp
return castExp
}
type castExpression struct {
expressionInterfaceImpl
expression Expression
cast string
}
func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
expression := b.expression
castType := b.cast
if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil {
castOverride(expression, String(castType))(statement, out, options...)
return
}
out.WriteString("CAST(")
expression.serialize(statement, out, options...)
out.WriteString("AS")
out.WriteString(castType + ")")
}

11
internal/jet/cast_test.go Normal file
View file

@ -0,0 +1,11 @@
package jet
import (
"testing"
)
func TestCastAS(t *testing.T) {
assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1))
assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
}

471
internal/jet/clause.go Normal file
View file

@ -0,0 +1,471 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
)
// Clause interface
type Clause interface {
Serialize(statementType StatementType, out *SQLBuilder)
}
// ClauseWithProjections interface
type ClauseWithProjections interface {
Clause
projections() ProjectionList
}
// ClauseSelect struct
type ClauseSelect struct {
Distinct bool
Projections []Projection
}
func (s *ClauseSelect) projections() ProjectionList {
return s.Projections
}
// Serialize serializes clause into SQLBuilder
func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) {
out.NewLine()
out.WriteString("SELECT")
if s.Distinct {
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.WriteProjections(statementType, s.Projections)
}
// ClauseFrom struct
type ClauseFrom struct {
Table Serializer
}
// Serialize serializes clause into SQLBuilder
func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) {
if f.Table == nil {
return
}
out.NewLine()
out.WriteString("FROM")
out.IncreaseIdent()
f.Table.serialize(statementType, out)
out.DecreaseIdent()
}
// ClauseWhere struct
type ClauseWhere struct {
Condition BoolExpression
Mandatory bool
}
// Serialize serializes clause into SQLBuilder
func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder) {
if c.Condition == nil {
if c.Mandatory {
panic("jet: WHERE clause not set")
}
return
}
out.NewLine()
out.WriteString("WHERE")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
out.DecreaseIdent()
}
// ClauseGroupBy struct
type ClauseGroupBy struct {
List []GroupByClause
}
// Serialize serializes clause into SQLBuilder
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder) {
if len(c.List) == 0 {
return
}
out.NewLine()
out.WriteString("GROUP BY")
out.IncreaseIdent()
for i, c := range c.List {
if i > 0 {
out.WriteString(", ")
}
if c == nil {
panic("jet: nil clause in GROUP BY list")
}
c.serializeForGroupBy(statementType, out)
}
out.DecreaseIdent()
}
// ClauseHaving struct
type ClauseHaving struct {
Condition BoolExpression
}
// Serialize serializes clause into SQLBuilder
func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
if c.Condition == nil {
return
}
out.NewLine()
out.WriteString("HAVING")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
out.DecreaseIdent()
}
// ClauseOrderBy struct
type ClauseOrderBy struct {
List []OrderByClause
}
// Serialize serializes clause into SQLBuilder
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) {
if o.List == nil {
return
}
out.NewLine()
out.WriteString("ORDER BY")
out.IncreaseIdent()
for i, value := range o.List {
if i > 0 {
out.WriteString(", ")
}
value.serializeForOrderBy(statementType, out)
}
out.DecreaseIdent()
}
// ClauseLimit struct
type ClauseLimit struct {
Count int64
}
// Serialize serializes clause into SQLBuilder
func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder) {
if l.Count >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(l.Count)
}
}
// ClauseOffset struct
type ClauseOffset struct {
Count int64
}
// Serialize serializes clause into SQLBuilder
func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder) {
if o.Count >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(o.Count)
}
}
// ClauseFor struct
type ClauseFor struct {
Lock RowLock
}
// Serialize serializes clause into SQLBuilder
func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder) {
if f.Lock == nil {
return
}
out.NewLine()
out.WriteString("FOR")
f.Lock.serialize(statementType, out)
}
// ClauseSetStmtOperator struct
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() ProjectionList {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}
return nil
}
// Serialize serializes clause into SQLBuilder
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder) {
if len(s.Selects) < 2 {
panic("jet: UNION Statement must contain at least two SELECT statements")
}
for i, selectStmt := range s.Selects {
out.NewLine()
if i > 0 {
out.WriteString(s.Operator)
if s.All {
out.WriteString("ALL")
}
out.NewLine()
}
if selectStmt == nil {
panic("jet: select statement of '" + s.Operator + "' is nil")
}
selectStmt.serialize(statementType, out)
}
s.OrderBy.Serialize(statementType, out)
s.Limit.Serialize(statementType, out)
s.Offset.Serialize(statementType, out)
}
// ClauseUpdate struct
type ClauseUpdate struct {
Table SerializerTable
}
// Serialize serializes clause into SQLBuilder
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) {
out.NewLine()
out.WriteString("UPDATE")
if utils.IsNil(u.Table) {
panic("jet: table to update is nil")
}
u.Table.serialize(statementType, out)
}
// ClauseSet struct
type ClauseSet struct {
Columns []Column
Values []Serializer
}
// Serialize serializes clause into SQLBuilder
func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) {
out.NewLine()
out.WriteString("SET")
if len(s.Columns) != len(s.Values) {
panic("jet: mismatch in numbers of columns and values for SET clause")
}
out.IncreaseIdent(4)
for i, column := range s.Columns {
if i > 0 {
out.WriteString(", ")
out.NewLine()
}
if column == nil {
panic("jet: nil column in columns list for SET clause")
}
out.WriteString(column.Name())
out.WriteString(" = ")
s.Values[i].serialize(UpdateStatementType, out)
}
out.DecreaseIdent(4)
}
// ClauseInsert struct
type ClauseInsert struct {
Table SerializerTable
Columns []Column
}
// GetColumns gets list of columns for insert
func (i *ClauseInsert) GetColumns() []Column {
if len(i.Columns) > 0 {
return i.Columns
}
return i.Table.columns()
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder) {
out.NewLine()
out.WriteString("INSERT INTO")
if utils.IsNil(i.Table) {
panic("jet: table is nil for INSERT clause")
}
i.Table.serialize(statementType, out)
if len(i.Columns) > 0 {
out.WriteString("(")
SerializeColumnNames(i.Columns, out)
out.WriteString(")")
}
}
// ClauseValuesQuery struct
type ClauseValuesQuery struct {
ClauseValues
ClauseQuery
}
// Serialize serializes clause into SQLBuilder
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder) {
if len(v.Rows) == 0 && v.Query == nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
if len(v.Rows) > 0 && v.Query != nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
v.ClauseValues.Serialize(statementType, out)
v.ClauseQuery.Serialize(statementType, out)
}
// ClauseValues struct
type ClauseValues struct {
Rows [][]Serializer
}
// Serialize serializes clause into SQLBuilder
func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder) {
if len(v.Rows) == 0 {
return
}
out.WriteString("VALUES")
for rowIndex, row := range v.Rows {
if rowIndex > 0 {
out.WriteString(",")
}
out.IncreaseIdent()
out.NewLine()
out.WriteString("(")
SerializeClauseList(statementType, row, out)
out.WriteByte(')')
out.DecreaseIdent()
}
}
// ClauseQuery struct
type ClauseQuery struct {
Query SerializerStatement
}
// Serialize serializes clause into SQLBuilder
func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder) {
if v.Query == nil {
return
}
v.Query.serialize(statementType, out)
}
// ClauseDelete struct
type ClauseDelete struct {
Table SerializerTable
}
// Serialize serializes clause into SQLBuilder
func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) {
out.NewLine()
out.WriteString("DELETE FROM")
if d.Table == nil {
panic("jet: nil table in DELETE clause")
}
d.Table.serialize(statementType, out)
}
// ClauseStatementBegin struct
type ClauseStatementBegin struct {
Name string
Tables []SerializerTable
}
// Serialize serializes clause into SQLBuilder
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder) {
out.NewLine()
out.WriteString(d.Name)
for i, table := range d.Tables {
if i > 0 {
out.WriteString(", ")
}
table.serialize(statementType, out)
}
}
// ClauseOptional struct
type ClauseOptional struct {
Name string
Show bool
InNewLine bool
}
// Serialize serializes clause into SQLBuilder
func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder) {
if !d.Show {
return
}
if d.InNewLine {
out.NewLine()
}
out.WriteString(d.Name)
}
// ClauseIn struct
type ClauseIn struct {
LockMode string
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) {
if i.LockMode == "" {
return
}
out.WriteString("IN")
out.WriteString(string(i.LockMode))
out.WriteString("MODE")
}

View file

@ -0,0 +1,16 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestClauseSelect_Serialize(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: SELECT clause has to have at least one projection")
}()
selectClause := &ClauseSelect{}
selectClause.Serialize(SelectStatementType, &SQLBuilder{})
}

148
internal/jet/column.go Normal file
View file

@ -0,0 +1,148 @@
// Modeling of columns
package jet
// Column is common column interface for all types of columns.
type Column interface {
Name() string
TableName() string
setTableName(table string)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// ColumnExpression interface
type ColumnExpression interface {
Column
Expression
}
// The base type for real materialized columns.
type columnImpl struct {
expressionInterfaceImpl
name string
tableName string
subQuery SelectTable
}
func newColumn(name string, tableName string, parent ColumnExpression) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
}
bc.expressionInterfaceImpl.Parent = parent
return bc
}
func (c *columnImpl) Name() string {
return c.name
}
func (c *columnImpl) TableName() string {
return c.tableName
}
func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string {
if c.tableName != "" {
return c.tableName + "." + c.name
}
return c.name
}
func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
if statement == SetStatementType {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.WriteAlias(c.defaultAlias()) //always quote
return
}
c.serialize(statement, out)
}
func (c columnImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
c.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(c.defaultAlias())
}
func (c columnImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true)
} else {
if c.tableName != "" {
out.WriteIdentifier(c.tableName)
out.WriteByte('.')
}
out.WriteIdentifier(c.name)
}
}
//------------------------------------------------------//
// IColumnList is used to store list of columns for later reuse as single projection or
// column list for UPDATE and INSERT statement.
type IColumnList interface {
Projection
Column
columns() []ColumnExpression
}
// ColumnList function returns list of columns that be used as projection or column list for UPDATE and INSERT statement.
func ColumnList(columns ...ColumnExpression) IColumnList {
return columnListImpl(columns)
}
// ColumnList is redefined type to support list of columns as single Projection
type columnListImpl []ColumnExpression
func (cl columnListImpl) columns() []ColumnExpression {
return cl
}
func (cl columnListImpl) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl columnListImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl columnListImpl) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl columnListImpl) TableName() string { return "" }
func (cl columnListImpl) setTableName(name string) {}
func (cl columnListImpl) setSubQuery(subQuery SelectTable) {}
func (cl columnListImpl) defaultAlias() string { return "" }

View file

@ -4,7 +4,7 @@ import "testing"
func TestColumn(t *testing.T) { func TestColumn(t *testing.T) {
column := newColumn("col", "", nil) column := newColumn("col", "", nil)
column.expressionInterfaceImpl.parent = &column column.expressionInterfaceImpl.Parent = &column
assertClauseSerialize(t, column, "col") assertClauseSerialize(t, column, "col")
column.setTableName("table1") column.setTableName("table1")

View file

@ -3,7 +3,7 @@ package jet
// ColumnBool is interface for SQL boolean columns. // ColumnBool is interface for SQL boolean columns.
type ColumnBool interface { type ColumnBool interface {
BoolExpression BoolExpression
column Column
From(subQuery SelectTable) ColumnBool From(subQuery SelectTable) ColumnBool
} }
@ -14,7 +14,7 @@ type boolColumnImpl struct {
columnImpl columnImpl
} }
func (i *boolColumnImpl) from(subQuery SelectTable) projection { func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection {
newBoolColumn := BoolColumn(i.name) newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(i.tableName) newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery) newBoolColumn.setSubQuery(subQuery)
@ -23,7 +23,7 @@ func (i *boolColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := i.from(subQuery).(ColumnBool) newBoolColumn := i.fromImpl(subQuery).(ColumnBool)
return newBoolColumn return newBoolColumn
} }
@ -42,7 +42,7 @@ func BoolColumn(name string) ColumnBool {
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column. // ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface { type ColumnFloat interface {
FloatExpression FloatExpression
column Column
From(subQuery SelectTable) ColumnFloat From(subQuery SelectTable) ColumnFloat
} }
@ -52,7 +52,7 @@ type floatColumnImpl struct {
columnImpl columnImpl
} }
func (i *floatColumnImpl) from(subQuery SelectTable) projection { func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection {
newFloatColumn := FloatColumn(i.name) newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName) newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery) newFloatColumn.setSubQuery(subQuery)
@ -61,7 +61,7 @@ func (i *floatColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
newFloatColumn := i.from(subQuery).(ColumnFloat) newFloatColumn := i.fromImpl(subQuery).(ColumnFloat)
return newFloatColumn return newFloatColumn
} }
@ -80,7 +80,7 @@ func FloatColumn(name string) ColumnFloat {
// ColumnInteger is interface for SQL smallint, integer, bigint columns. // ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface { type ColumnInteger interface {
IntegerExpression IntegerExpression
column Column
From(subQuery SelectTable) ColumnInteger From(subQuery SelectTable) ColumnInteger
} }
@ -91,7 +91,7 @@ type integerColumnImpl struct {
columnImpl columnImpl
} }
func (i *integerColumnImpl) from(subQuery SelectTable) projection { func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection {
newIntColumn := IntegerColumn(i.name) newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(i.tableName) newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery) newIntColumn.setSubQuery(subQuery)
@ -100,7 +100,7 @@ func (i *integerColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return i.from(subQuery).(ColumnInteger) return i.fromImpl(subQuery).(ColumnInteger)
} }
// IntegerColumn creates named integer column. // IntegerColumn creates named integer column.
@ -118,7 +118,7 @@ func IntegerColumn(name string) ColumnInteger {
// bytea, uuid columns and enums types. // bytea, uuid columns and enums types.
type ColumnString interface { type ColumnString interface {
StringExpression StringExpression
column Column
From(subQuery SelectTable) ColumnString From(subQuery SelectTable) ColumnString
} }
@ -129,7 +129,7 @@ type stringColumnImpl struct {
columnImpl columnImpl
} }
func (i *stringColumnImpl) from(subQuery SelectTable) projection { func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection {
newStrColumn := StringColumn(i.name) newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(i.tableName) newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery) newStrColumn.setSubQuery(subQuery)
@ -138,7 +138,7 @@ func (i *stringColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return i.from(subQuery).(ColumnString) return i.fromImpl(subQuery).(ColumnString)
} }
// StringColumn creates named string column. // StringColumn creates named string column.
@ -155,7 +155,7 @@ func StringColumn(name string) ColumnString {
// ColumnTime is interface for SQL time column. // ColumnTime is interface for SQL time column.
type ColumnTime interface { type ColumnTime interface {
TimeExpression TimeExpression
column Column
From(subQuery SelectTable) ColumnTime From(subQuery SelectTable) ColumnTime
} }
@ -165,7 +165,7 @@ type timeColumnImpl struct {
columnImpl columnImpl
} }
func (i *timeColumnImpl) from(subQuery SelectTable) projection { func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimeColumn := TimeColumn(i.name) newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(i.tableName) newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery) newTimeColumn.setSubQuery(subQuery)
@ -174,7 +174,7 @@ func (i *timeColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return i.from(subQuery).(ColumnTime) return i.fromImpl(subQuery).(ColumnTime)
} }
// TimeColumn creates named time column // TimeColumn creates named time column
@ -190,7 +190,7 @@ func TimeColumn(name string) ColumnTime {
// ColumnTimez is interface of SQL time with time zone columns. // ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface { type ColumnTimez interface {
TimezExpression TimezExpression
column Column
From(subQuery SelectTable) ColumnTimez From(subQuery SelectTable) ColumnTimez
} }
@ -201,7 +201,7 @@ type timezColumnImpl struct {
columnImpl columnImpl
} }
func (i *timezColumnImpl) from(subQuery SelectTable) projection { func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimezColumn := TimezColumn(i.name) newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(i.tableName) newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery) newTimezColumn.setSubQuery(subQuery)
@ -210,7 +210,7 @@ func (i *timezColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return i.from(subQuery).(ColumnTimez) return i.fromImpl(subQuery).(ColumnTimez)
} }
// TimezColumn creates named time with time zone column. // TimezColumn creates named time with time zone column.
@ -227,7 +227,7 @@ func TimezColumn(name string) ColumnTimez {
// ColumnTimestamp is interface of SQL timestamp columns. // ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface { type ColumnTimestamp interface {
TimestampExpression TimestampExpression
column Column
From(subQuery SelectTable) ColumnTimestamp From(subQuery SelectTable) ColumnTimestamp
} }
@ -238,7 +238,7 @@ type timestampColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampColumnImpl) from(subQuery SelectTable) projection { func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimestampColumn := TimestampColumn(i.name) newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(i.tableName) newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery) newTimestampColumn.setSubQuery(subQuery)
@ -247,7 +247,7 @@ func (i *timestampColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return i.from(subQuery).(ColumnTimestamp) return i.fromImpl(subQuery).(ColumnTimestamp)
} }
// TimestampColumn creates named timestamp column // TimestampColumn creates named timestamp column
@ -264,7 +264,7 @@ func TimestampColumn(name string) ColumnTimestamp {
// ColumnTimestampz is interface of SQL timestamp with timezone columns. // ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface { type ColumnTimestampz interface {
TimestampzExpression TimestampzExpression
column Column
From(subQuery SelectTable) ColumnTimestampz From(subQuery SelectTable) ColumnTimestampz
} }
@ -275,7 +275,7 @@ type timestampzColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampzColumnImpl) from(subQuery SelectTable) projection { func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimestampzColumn := TimestampzColumn(i.name) newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(i.tableName) newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery) newTimestampzColumn.setSubQuery(subQuery)
@ -284,7 +284,7 @@ func (i *timestampzColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return i.from(subQuery).(ColumnTimestampz) return i.fromImpl(subQuery).(ColumnTimestampz)
} }
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
@ -301,7 +301,7 @@ func TimestampzColumn(name string) ColumnTimestampz {
// ColumnDate is interface of SQL date columns. // ColumnDate is interface of SQL date columns.
type ColumnDate interface { type ColumnDate interface {
DateExpression DateExpression
column Column
From(subQuery SelectTable) ColumnDate From(subQuery SelectTable) ColumnDate
} }
@ -312,7 +312,7 @@ type dateColumnImpl struct {
columnImpl columnImpl
} }
func (i *dateColumnImpl) from(subQuery SelectTable) projection { func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection {
newDateColumn := DateColumn(i.name) newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(i.tableName) newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery) newDateColumn.setSubQuery(subQuery)
@ -321,7 +321,7 @@ func (i *dateColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return i.from(subQuery).(ColumnDate) return i.fromImpl(subQuery).(ColumnDate)
} }
// DateColumn creates named date column. // DateColumn creates named date column.

View file

@ -4,7 +4,9 @@ import (
"testing" "testing"
) )
var subQuery = table1.SELECT(table1ColFloat, table1ColInt).AsTable("sub_query") var subQuery = &selectTableImpl{
alias: "sub_query",
}
func TestNewBoolColumn(t *testing.T) { func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery) boolColumn := BoolColumn("colBool").From(subQuery)

View file

@ -1,6 +1,6 @@
package jet package jet
// DateExpression is interface for all SQL date expressions. // DateExpression is interface for date types
type DateExpression interface { type DateExpression interface {
Expression Expression

91
internal/jet/dialect.go Normal file
View file

@ -0,0 +1,91 @@
package jet
// Dialect interface
type Dialect interface {
Name() string
PackageName() string
OperatorSerializeOverride(operator string) SerializeOverride
FunctionSerializeOverride(function string) SerializeOverride
AliasQuoteChar() byte
IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc
}
// SerializeFunc func
type SerializeFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption)
// SerializeOverride func
type SerializeOverride func(expressions ...Expression) SerializeFunc
// QueryPlaceholderFunc func
type QueryPlaceholderFunc func(ord int) string
// DialectParams struct
type DialectParams struct {
Name string
PackageName string
OperatorSerializeOverrides map[string]SerializeOverride
FunctionSerializeOverrides map[string]SerializeOverride
AliasQuoteChar byte
IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc
}
// NewDialect creates new dialect with params
func NewDialect(params DialectParams) Dialect {
return &dialectImpl{
name: params.Name,
packageName: params.PackageName,
operatorSerializeOverrides: params.OperatorSerializeOverrides,
functionSerializeOverrides: params.FunctionSerializeOverrides,
aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder,
}
}
type dialectImpl struct {
name string
packageName string
operatorSerializeOverrides map[string]SerializeOverride
functionSerializeOverrides map[string]SerializeOverride
aliasQuoteChar byte
identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc
supportsReturning bool
}
func (d *dialectImpl) Name() string {
return d.name
}
func (d *dialectImpl) PackageName() string {
return d.packageName
}
func (d *dialectImpl) OperatorSerializeOverride(operator string) SerializeOverride {
if d.operatorSerializeOverrides == nil {
return nil
}
return d.operatorSerializeOverrides[operator]
}
func (d *dialectImpl) FunctionSerializeOverride(function string) SerializeOverride {
if d.functionSerializeOverrides == nil {
return nil
}
return d.functionSerializeOverrides[function]
}
func (d *dialectImpl) AliasQuoteChar() byte {
return d.aliasQuoteChar
}
func (d *dialectImpl) IdentifierQuoteChar() byte {
return d.identifierQuoteChar
}
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder
}

View file

@ -3,6 +3,7 @@ package jet
type enumValue struct { type enumValue struct {
expressionInterfaceImpl expressionInterfaceImpl
stringInterfaceImpl stringInterfaceImpl
name string name string
} }
@ -10,13 +11,12 @@ type enumValue struct {
func NewEnumValue(name string) StringExpression { func NewEnumValue(name string) StringExpression {
enumValue := &enumValue{name: name} enumValue := &enumValue{name: name}
enumValue.expressionInterfaceImpl.parent = enumValue enumValue.expressionInterfaceImpl.Parent = enumValue
enumValue.stringInterfaceImpl.parent = enumValue enumValue.stringInterfaceImpl.parent = enumValue
return enumValue return enumValue
} }
func (e enumValue) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (e enumValue) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.insertConstantArgument(e.name) out.insertConstantArgument(e.name)
return nil
} }

178
internal/jet/expression.go Normal file
View file

@ -0,0 +1,178 @@
package jet
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
Serializer
Projection
GroupByClause
OrderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression
AS(alias string) Projection
// Expression will be used to sort query result in ascending order
ASC() OrderByClause
// Expression will be used to sort query result in ascending order
DESC() OrderByClause
}
type expressionInterfaceImpl struct {
Parent Expression
}
func (e *expressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
return e.Parent
}
func (e *expressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.Parent, "IS NULL")
}
func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostifxBoolExpression(e.Parent, "IS NOT NULL")
}
func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN")
}
func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN")
}
func (e *expressionInterfaceImpl) AS(alias string) Projection {
return newAlias(e.Parent, alias)
}
func (e *expressionInterfaceImpl) ASC() OrderByClause {
return newOrderByClause(e.Parent, true)
}
func (e *expressionInterfaceImpl) DESC() OrderByClause {
return newOrderByClause(e.Parent, false)
}
func (e *expressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct {
lhs, rhs Expression
additionalParam Expression
operator string
}
func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam ...Expression) binaryOpExpression {
binaryExpression := binaryOpExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
if len(additionalParam) > 0 {
binaryExpression.additionalParam = additionalParam[0]
}
return binaryExpression
}
func (c *binaryOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator")
}
if c.rhs == nil {
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, noWrap)
if wrap {
out.WriteString("(")
}
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, options...)
} else {
c.lhs.serialize(statement, out)
out.WriteString(c.operator)
c.rhs.serialize(statement, out)
}
if wrap {
out.WriteString(")")
}
}
// A prefix operator Expression
type prefixOpExpression struct {
expression Expression
operator string
}
func newPrefixExpression(expression Expression, operator string) prefixOpExpression {
prefixExpression := prefixOpExpression{
expression: expression,
operator: operator,
}
return prefixExpression
}
func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
out.WriteString(p.operator)
if p.expression == nil {
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out)
out.WriteString(")")
}
// A postifx operator Expression
type postfixOpExpression struct {
expression Expression
operator string
}
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression {
postfixOpExpression := postfixOpExpression{
expression: expression,
operator: operator,
}
return postfixOpExpression
}
func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if p.expression == nil {
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out)
out.WriteString(p.operator)
}

View file

@ -4,10 +4,13 @@ import (
"testing" "testing"
) )
func TestInvalidExpression(t *testing.T) {
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
}
func TestExpressionIS_NULL(t *testing.T) { func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")
assertClauseSerializeErr(t, table2Col3.ADD(nil), "jet: nil rhs")
} }
func TestExpressionIS_NOT_NULL(t *testing.T) { func TestExpressionIS_NOT_NULL(t *testing.T) {
@ -26,33 +29,14 @@ func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
} }
func TestIN(t *testing.T) { func TestIN(t *testing.T) {
assertClauseSerialize(t, table2ColInt.IN(Int(1), Int(2), Int(3)),
`(table2.col_int IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
} }
func TestNOT_IN(t *testing.T) { func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), assertClauseSerialize(t, table2ColInt.NOT_IN(Int(1), Int(2), Int(3)),
`($1 NOT IN (( `(table2.col_int NOT IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
} }

View file

@ -81,7 +81,7 @@ func (n *floatInterfaceImpl) MOD(expression NumericExpression) FloatExpression {
} }
func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "^") return POW(n.parent, expression)
} }
//---------------------------------------------------// //---------------------------------------------------//
@ -97,7 +97,7 @@ func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpress
floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
floatExpression.expressionInterfaceImpl.parent = &floatExpression floatExpression.expressionInterfaceImpl.Parent = &floatExpression
floatExpression.floatInterfaceImpl.parent = &floatExpression floatExpression.floatInterfaceImpl.parent = &floatExpression
return &floatExpression return &floatExpression

View file

@ -70,8 +70,8 @@ func TestFloatExpressionMOD(t *testing.T) {
} }
func TestFloatExpressionPOW(t *testing.T) { func TestFloatExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "(table1.col_float ^ table2.col_float)") assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "POW(table1.col_float, table2.col_float)")
assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "(table1.col_float ^ $1)", float64(2.11)) assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "POW(table1.col_float, $1)", float64(2.11))
} }
func TestFloatExp(t *testing.T) { func TestFloatExp(t *testing.T) {

View file

@ -1,7 +1,5 @@
package jet package jet
import "errors"
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression { func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil) return newFunc("ROW", expressions, nil)
@ -11,7 +9,7 @@ func ROW(expressions ...Expression) Expression {
// ABSf calculates absolute value from float expression // ABSf calculates absolute value from float expression
func ABSf(floatExpression FloatExpression) FloatExpression { func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression) return NewFloatFunc("ABS", floatExpression)
} }
// ABSi calculates absolute value from int expression // ABSi calculates absolute value from int expression
@ -19,62 +17,72 @@ func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression) return newIntegerFunc("ABS", integerExpression)
} }
// POW calculates power of base with exponent
func POW(base, exponent NumericExpression) FloatExpression {
return NewFloatFunc("POW", base, exponent)
}
// POWER calculates power of base with exponent
func POWER(base, exponent NumericExpression) FloatExpression {
return NewFloatFunc("POWER", base, exponent)
}
// SQRT calculates square root of numeric expression // SQRT calculates square root of numeric expression
func SQRT(numericExpression NumericExpression) FloatExpression { func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression) return NewFloatFunc("SQRT", numericExpression)
} }
// CBRT calculates cube root of numeric expression // CBRT calculates cube root of numeric expression
func CBRT(numericExpression NumericExpression) FloatExpression { func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression) return NewFloatFunc("CBRT", numericExpression)
} }
// CEIL calculates ceil of float expression // CEIL calculates ceil of float expression
func CEIL(floatExpression FloatExpression) FloatExpression { func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression) return NewFloatFunc("CEIL", floatExpression)
} }
// FLOOR calculates floor of float expression // FLOOR calculates floor of float expression
func FLOOR(floatExpression FloatExpression) FloatExpression { func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression) return NewFloatFunc("FLOOR", floatExpression)
} }
// ROUND calculates round of a float expressions with optional precision // ROUND calculates round of a float expressions with optional precision
func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression { func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 { if len(precision) > 0 {
return newFloatFunc("ROUND", floatExpression, precision[0]) return NewFloatFunc("ROUND", floatExpression, precision[0])
} }
return newFloatFunc("ROUND", floatExpression) return NewFloatFunc("ROUND", floatExpression)
} }
// SIGN returns sign of float expression // SIGN returns sign of float expression
func SIGN(floatExpression FloatExpression) FloatExpression { func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression) return NewFloatFunc("SIGN", floatExpression)
} }
// TRUNC calculates trunc of float expression with optional precision // TRUNC calculates trunc of float expression with optional precision
func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression { func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 { if len(precision) > 0 {
return newFloatFunc("TRUNC", floatExpression, precision[0]) return NewFloatFunc("TRUNC", floatExpression, precision[0])
} }
return newFloatFunc("TRUNC", floatExpression) return NewFloatFunc("TRUNC", floatExpression)
} }
// LN calculates natural algorithm of float expression // LN calculates natural algorithm of float expression
func LN(floatExpression FloatExpression) FloatExpression { func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression) return NewFloatFunc("LN", floatExpression)
} }
// LOG calculates logarithm of float expression // LOG calculates logarithm of float expression
func LOG(floatExpression FloatExpression) FloatExpression { func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression) return NewFloatFunc("LOG", floatExpression)
} }
// ----------------- Aggregate functions -------------------// // ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression // AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression { func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression) return NewFloatFunc("AVG", numericExpression)
} }
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. // BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
@ -109,7 +117,7 @@ func EVERY(boolExpression BoolExpression) BoolExpression {
// MAXf is aggregate function. Returns maximum value of float expression across all input values // MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression { func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression) return NewFloatFunc("MAX", floatExpression)
} }
// MAXi is aggregate function. Returns maximum value of int expression across all input values // MAXi is aggregate function. Returns maximum value of int expression across all input values
@ -119,7 +127,7 @@ func MAXi(integerExpression IntegerExpression) IntegerExpression {
// MINf is aggregate function. Returns minimum value of float expression across all input values // MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression { func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression) return NewFloatFunc("MIN", floatExpression)
} }
// MINi is aggregate function. Returns minimum value of int expression across all input values // MINi is aggregate function. Returns minimum value of int expression across all input values
@ -129,7 +137,7 @@ func MINi(integerExpression IntegerExpression) IntegerExpression {
// SUMf is aggregate function. Returns sum of expression across all float expressions // SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression { func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression) return NewFloatFunc("SUM", floatExpression)
} }
// SUMi is aggregate function. Returns sum of expression across all integer expression. // SUMi is aggregate function. Returns sum of expression across all integer expression.
@ -196,14 +204,15 @@ func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression) return newStringFunc("CHR", integerExpression)
} }
// // CONCAT adds two or more expressions together
//func CONCAT(expressions ...Expression) StringExpression { func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...) return newStringFunc("CONCAT", expressions...)
//} }
//
//func CONCAT_WS(expressions ...Expression) StringExpression { // CONCAT_WS adds two or more expressions together with a separator.
// return newStringFunc("CONCAT_WS", expressions...) func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression {
//} return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...)
}
// CONVERT converts string to dest_encoding. The original encoding is // CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding. // specified by src_encoding. The string must be valid in this encoding.
@ -235,11 +244,12 @@ func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format) return newStringFunc("DECODE", data, format)
} }
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression { // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
// args := []expressions{formatStr} func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
// args = append(args, formatArgs...) args := []Expression{formatStr}
// return newStringFunc("FORMAT", args...) args = append(args, formatArgs...)
//} return newStringFunc("FORMAT", args...)
}
// INITCAP converts the first letter of each word to upper case // INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric // and the rest to lower case. Words are sequences of alphanumeric
@ -336,6 +346,15 @@ func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number) return newStringFunc("TO_HEX", number)
} }
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression {
if len(matchType) > 0 {
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0]))
}
return newBoolFunc("REGEXP_LIKE", stringExp, pattern)
}
//----------Data Type Formatting Functions ----------------------// //----------Data Type Formatting Functions ----------------------//
// TO_CHAR converts expression to string with format // TO_CHAR converts expression to string with format
@ -350,7 +369,7 @@ func TO_DATE(dateStr, format StringExpression) DateExpression {
// TO_NUMBER converts string to numeric using format // TO_NUMBER converts string to numeric using format
func TO_NUMBER(floatStr, format StringExpression) FloatExpression { func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format) return NewFloatFunc("TO_NUMBER", floatStr, format)
} }
// TO_TIMESTAMP converts string to time stamp with time zone using format // TO_TIMESTAMP converts string to time stamp with time zone using format
@ -372,7 +391,7 @@ func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc var timezFunc *timezFunc
if len(precision) > 0 { if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0])) timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0]))
} else { } else {
timezFunc = newTimezFunc("CURRENT_TIME") timezFunc = newTimezFunc("CURRENT_TIME")
} }
@ -387,7 +406,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc var timestampzFunc *timestampzFunc
if len(precision) > 0 { if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0])) timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0]))
} else { } else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
} }
@ -402,7 +421,7 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc var timeFunc *timeFunc
if len(precision) > 0 { if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0])) timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0]))
} else { } else {
timeFunc = newTimeFunc("LOCALTIME") timeFunc = newTimeFunc("LOCALTIME")
} }
@ -417,9 +436,9 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc var timestampFunc *timestampFunc
if len(precision) > 0 { if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0])) timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0]))
} else { } else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP") timestampFunc = NewTimestampFunc("LOCALTIMESTAMP")
} }
timestampFunc.noBrackets = true timestampFunc.noBrackets = true
@ -477,37 +496,34 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
} }
if parent != nil { if parent != nil {
funcExp.expressionInterfaceImpl.parent = parent funcExp.expressionInterfaceImpl.Parent = parent
} else { } else {
funcExp.expressionInterfaceImpl.parent = funcExp funcExp.expressionInterfaceImpl.Parent = funcExp
} }
return funcExp return funcExp
} }
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if f == nil { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
return errors.New("jet: Function expressions is nil. ") serializeOverrideFunc := serializeOverride(f.expressions...)
serializeOverrideFunc(statement, out, options...)
return
} }
addBrackets := !f.noBrackets || len(f.expressions) > 0 addBrackets := !f.noBrackets || len(f.expressions) > 0
if addBrackets { if addBrackets {
out.writeString(f.name + "(") out.WriteString(f.name + "(")
} else { } else {
out.writeString(f.name) out.WriteString(f.name)
} }
err := serializeExpressionList(statement, f.expressions, ", ", out) serializeExpressionList(statement, f.expressions, ", ", out)
if err != nil {
return err
}
if addBrackets { if addBrackets {
out.writeString(")") out.WriteString(")")
} }
return nil
} }
type boolFunc struct { type boolFunc struct {
@ -529,7 +545,8 @@ type floatFunc struct {
floatInterfaceImpl floatInterfaceImpl
} }
func newFloatFunc(name string, expressions ...Expression) FloatExpression { // NewFloatFunc creates new float function with name and expressions
func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
floatFunc := &floatFunc{} floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
@ -613,7 +630,8 @@ type timestampFunc struct {
timestampInterfaceImpl timestampInterfaceImpl
} }
func newTimestampFunc(name string, expressions ...Expression) *timestampFunc { // NewTimestampFunc creates new timestamp function with name and expressions
func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc {
timestampFunc := &timestampFunc{} timestampFunc := &timestampFunc{}
timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc)

View file

@ -0,0 +1,6 @@
package jet
// GroupByClause interface
type GroupByClause interface {
serializeForGroupBy(statement StatementType, out *SQLBuilder)
}

View file

@ -106,7 +106,7 @@ func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpressi
} }
func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "^") return IntExp(POW(i.parent, expression))
} }
func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression {
@ -140,7 +140,7 @@ type binaryIntegerExpression struct {
func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression {
integerExpression := binaryIntegerExpression{} integerExpression := binaryIntegerExpression{}
integerExpression.expressionInterfaceImpl.parent = &integerExpression integerExpression.expressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression
integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
@ -160,12 +160,30 @@ func newPrefixIntegerOperator(expression IntegerExpression, operator string) Int
integerExpression := prefixIntegerOpExpression{} integerExpression := prefixIntegerOpExpression{}
integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) integerExpression.prefixOpExpression = newPrefixExpression(expression, operator)
integerExpression.expressionInterfaceImpl.parent = &integerExpression integerExpression.expressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression
return &integerExpression return &integerExpression
} }
//---------------------------------------------------//
type prefixFloatOpExpression struct {
expressionInterfaceImpl
floatInterfaceImpl
prefixOpExpression
}
func newPrefixFloatOperator(expression FloatExpression, operator string) FloatExpression {
floatOpExpression := prefixFloatOpExpression{}
floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator)
floatOpExpression.expressionInterfaceImpl.Parent = &floatOpExpression
floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression
return &floatOpExpression
}
//---------------------------------------------------// //---------------------------------------------------//
type integerExpressionWrapper struct { type integerExpressionWrapper struct {
integerInterfaceImpl integerInterfaceImpl

View file

@ -60,13 +60,28 @@ func TestIntExpressionMOD(t *testing.T) {
} }
func TestIntExpressionPOW(t *testing.T) { func TestIntExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "(table1.col_int ^ table2.col_int)") assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)")
assertClauseSerialize(t, table1ColInt.POW(Int(11)), "(table1.col_int ^ $1)", int64(11)) assertClauseSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, $1)", int64(11))
} }
func TestIntExpressionBIT_NOT(t *testing.T) { func TestIntExpressionBIT_NOT(t *testing.T) {
assertClauseSerialize(t, BIT_NOT(table2ColInt), "~ table2.col_int") assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)")
assertClauseSerialize(t, BIT_NOT(Int(11)), "~ $1", int64(11)) assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)")
}
func TestIntExpressionBIT_AND(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_AND(table2ColInt), "(table1.col_int & table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_AND(Int(11)), "(table1.col_int & $1)", int64(11))
}
func TestIntExpressionBIT_OR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_OR(table2ColInt), "(table1.col_int | table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_OR(Int(11)), "(table1.col_int | $1)", int64(11))
}
func TestIntExpressionBIT_XOR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int # table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int # $1)", int64(11))
} }
func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) { func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) {

View file

@ -14,8 +14,6 @@ var (
type keywordClause string type keywordClause string
func (k keywordClause) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (k keywordClause) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.writeString(string(k)) out.WriteString(string(k))
return nil
} }

View file

@ -0,0 +1,348 @@
package jet
import (
"fmt"
"time"
)
// LiteralExpression is representation of an escaped literal
type LiteralExpression interface {
Expression
Value() interface{}
SetConstant(constant bool)
}
type literalExpressionImpl struct {
expressionInterfaceImpl
value interface{}
constant bool
}
func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl {
exp := literalExpressionImpl{value: value}
if len(optionalConstant) > 0 {
exp.constant = optionalConstant[0]
}
exp.expressionInterfaceImpl.Parent = &exp
return &exp
}
// ConstLiteral is injected directly to SQL query, and does not appear in argument list.
func ConstLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value)
exp.constant = true
return exp
}
func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if l.constant {
out.insertConstantArgument(l.value)
} else {
out.insertParametrizedArgument(l.value)
}
}
func (l *literalExpressionImpl) Value() interface{} {
return l.value
}
func (l *literalExpressionImpl) SetConstant(constant bool) {
l.constant = constant
}
type integerLiteralExpression struct {
literalExpressionImpl
integerInterfaceImpl
}
// Int creates new integer literal
func Int(value int64) IntegerExpression {
numLiteral := &integerLiteralExpression{}
numLiteral.literalExpressionImpl = *literal(value)
numLiteral.literalExpressionImpl.Parent = numLiteral
numLiteral.integerInterfaceImpl.parent = numLiteral
return numLiteral
}
//---------------------------------------------------//
type boolLiteralExpression struct {
boolInterfaceImpl
literalExpressionImpl
}
// Bool creates new bool literal expression
func Bool(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{}
boolLiteralExpression.literalExpressionImpl = *literal(value)
boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression
return &boolLiteralExpression
}
//---------------------------------------------------//
type floatLiteral struct {
floatInterfaceImpl
literalExpressionImpl
}
// Float creates new float literal
func Float(value float64) FloatExpression {
floatLiteral := floatLiteral{}
floatLiteral.literalExpressionImpl = *literal(value)
floatLiteral.floatInterfaceImpl.parent = &floatLiteral
return &floatLiteral
}
//---------------------------------------------------//
type stringLiteral struct {
stringInterfaceImpl
literalExpressionImpl
}
// String creates new string literal expression
func String(value string) StringExpression {
stringLiteral := stringLiteral{}
stringLiteral.literalExpressionImpl = *literal(value)
stringLiteral.stringInterfaceImpl.parent = &stringLiteral
return &stringLiteral
}
//---------------------------------------------------//
type timeLiteral struct {
timeInterfaceImpl
literalExpressionImpl
}
// Time creates new time literal expression
func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression {
timeLiteral := &timeLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...)
timeLiteral.literalExpressionImpl = *literal(timeStr)
timeLiteral.timeInterfaceImpl.parent = timeLiteral
return timeLiteral
}
// TimeT creates new time literal expression from time.Time object
func TimeT(t time.Time) TimeExpression {
timeLiteral := &timeLiteral{}
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timeInterfaceImpl.parent = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timezLiteral struct {
timezInterfaceImpl
literalExpressionImpl
}
// Timez creates new time with time zone literal expression
func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression {
timezLiteral := timezLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone
timezLiteral.literalExpressionImpl = *literal(timeStr)
return TimezExp(literal(timeStr))
}
// TimezT creates new time with time zone literal expression from time.Time object
func TimezT(t time.Time) TimezExpression {
timeLiteral := &timezLiteral{}
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timezInterfaceImpl.parent = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timestampLiteral struct {
timestampInterfaceImpl
literalExpressionImpl
}
// Timestamp creates new timestamp literal expression
func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression {
timestamp := &timestampLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...)
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampInterfaceImpl.parent = timestamp
return timestamp
}
// TimestampT creates new timestamp literal expression from time.Time object
func TimestampT(t time.Time) TimestampExpression {
timestamp := &timestampLiteral{}
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampInterfaceImpl.parent = timestamp
return timestamp
}
//---------------------------------------------------//
type timestampzLiteral struct {
timestampzInterfaceImpl
literalExpressionImpl
}
// Timestampz creates new timestamp with time zone literal expression
func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression {
timestamp := &timestampzLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampzInterfaceImpl.parent = timestamp
return timestamp
}
// TimestampzT creates new timestamp literal expression from time.Time object
func TimestampzT(t time.Time) TimestampzExpression {
timestamp := &timestampzLiteral{}
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampzInterfaceImpl.parent = timestamp
return timestamp
}
//---------------------------------------------------//
type dateLiteral struct {
dateInterfaceImpl
literalExpressionImpl
}
// Date creates new date literal expression
func Date(year int, month time.Month, day int) DateExpression {
dateLiteral := &dateLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day)
dateLiteral.literalExpressionImpl = *literal(timeStr)
dateLiteral.dateInterfaceImpl.parent = dateLiteral
return dateLiteral
}
// DateT creates new date literal expression from time.Time object
func DateT(t time.Time) DateExpression {
dateLiteral := &dateLiteral{}
dateLiteral.literalExpressionImpl = *literal(t)
dateLiteral.dateInterfaceImpl.parent = dateLiteral
return dateLiteral
}
func formatNanoseconds(nanoseconds ...time.Duration) string {
if len(nanoseconds) > 0 && nanoseconds[0] != 0 {
duration := fmt.Sprintf("%09d", nanoseconds[0])
i := len(duration) - 1
for ; i >= 3; i-- {
if duration[i] != '0' {
break
}
}
return "." + duration[0:i+1]
}
return ""
}
//--------------------------------------------------//
type nullLiteral struct {
expressionInterfaceImpl
}
func newNullLiteral() Expression {
nullExpression := &nullLiteral{}
nullExpression.expressionInterfaceImpl.Parent = nullExpression
return nullExpression
}
func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("NULL")
}
//--------------------------------------------------//
type starLiteral struct {
expressionInterfaceImpl
}
func newStarLiteral() Expression {
starExpression := &starLiteral{}
starExpression.expressionInterfaceImpl.Parent = starExpression
return starExpression
}
func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("*")
}
//---------------------------------------------------//
type wrap struct {
expressionInterfaceImpl
expressions []Expression
}
func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
serializeExpressionList(statement, n.expressions, ", ", out)
out.WriteString(")")
}
// WRAP wraps list of expressions with brackets '(' and ')'
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.expressionInterfaceImpl.Parent = wrap
return wrap
}
//---------------------------------------------------//
type rawExpression struct {
expressionInterfaceImpl
raw string
}
func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(n.raw)
}
// Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()")
func Raw(raw string) Expression {
rawExp := &rawExpression{raw: raw}
rawExp.expressionInterfaceImpl.Parent = rawExp
return rawExp
}

View file

@ -0,0 +1,60 @@
package jet
import (
"testing"
"time"
)
func TestRawExpression(t *testing.T) {
assertClauseSerialize(t, Raw("current_database()"), "current_database()")
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
assertClauseSerialize(t, DateT(timeT), "$1", timeT)
}
func TestTimeLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Time(11, 5, 30), "'11:05:30'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 0), "'11:05:30'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 3*time.Millisecond), "'11:05:30.003'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 30*time.Millisecond), "'11:05:30.030'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 300*time.Millisecond), "'11:05:30.300'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 300*time.Microsecond), "'11:05:30.0003'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 4*time.Nanosecond), "'11:05:30.000000004'")
}
func TestTimeT(t *testing.T) {
timeT := time.Date(2000, 1, 1, 11, 40, 20, 124, time.UTC)
assertClauseDebugSerialize(t, TimeT(timeT), `'2000-01-01 11:40:20.000000124Z'`)
}
func TestTimezLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timez(11, 5, 30, 10*time.Nanosecond, "UTC"), "'11:05:30.00000001 UTC'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 0, "+1"), "'11:05:30 +1'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 3*time.Microsecond, "-7"), "'11:05:30.000003 -7'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 30*time.Millisecond, "+8:00"), "'11:05:30.030 +8:00'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 300*time.Nanosecond, "America/New_Yor"), "'11:05:30.0000003 America/New_Yor'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 3000*time.Nanosecond, "zulu"), "'11:05:30.000003 zulu'")
}
func TestTimestampLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timestamp(2011, 1, 8, 11, 5, 30), "'2011-01-08 11:05:30'")
assertClauseDebugSerialize(t, Timestamp(2011, 2, 7, 11, 5, 30, 0), "'2011-02-07 11:05:30'")
assertClauseDebugSerialize(t, Timestamp(2011, 3, 6, 11, 5, 30, 3*time.Millisecond), "'2011-03-06 11:05:30.003'")
assertClauseDebugSerialize(t, Timestamp(2011, 4, 5, 11, 5, 30, 30*time.Millisecond), "'2011-04-05 11:05:30.030'")
assertClauseDebugSerialize(t, Timestamp(2011, 5, 4, 11, 5, 30, 300*time.Millisecond), "'2011-05-04 11:05:30.300'")
assertClauseDebugSerialize(t, Timestamp(2011, 6, 3, 11, 5, 30, 3000*time.Microsecond), "'2011-06-03 11:05:30.003'")
}
func TestTimestampzLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timestampz(2011, 1, 8, 11, 5, 30, 0, "UTC"), "'2011-01-08 11:05:30 UTC'")
assertClauseDebugSerialize(t, Timestampz(2011, 2, 7, 11, 5, 30, 0, "PST"), "'2011-02-07 11:05:30 PST'")
assertClauseDebugSerialize(t, Timestampz(2011, 3, 6, 11, 5, 30, 3, "+4:00"), "'2011-03-06 11:05:30.000000003 +4:00'")
assertClauseDebugSerialize(t, Timestampz(2011, 4, 5, 11, 5, 30, 30, "-8:00"), "'2011-04-05 11:05:30.00000003 -8:00'")
assertClauseDebugSerialize(t, Timestampz(2011, 5, 4, 11, 5, 30, 300, "400"), "'2011-05-04 11:05:30.0000003 400'")
assertClauseDebugSerialize(t, Timestampz(2011, 6, 3, 11, 5, 30, 3000, "zulu"), "'2011-06-03 11:05:30.000003 zulu'")
}
func TestDate(t *testing.T) {
assertClauseDebugSerialize(t, Date(2019, 8, 8), `'2019-08-08'`)
}

View file

@ -1,6 +1,11 @@
package jet package jet
import "errors" // Operators
const (
StringConcatOperator = "||"
StringRegexpLikeOperator = "REGEXP"
StringNotRegexpLikeOperator = "NOT REGEXP"
)
//----------- Logical operators ---------------// //----------- Logical operators ---------------//
@ -11,13 +16,16 @@ func NOT(exp BoolExpression) BoolExpression {
// BIT_NOT inverts every bit in integer expression result // BIT_NOT inverts every bit in integer expression result
func BIT_NOT(expr IntegerExpression) IntegerExpression { func BIT_NOT(expr IntegerExpression) IntegerExpression {
if literalExp, ok := expr.(LiteralExpression); ok {
literalExp.SetConstant(true)
}
return newPrefixIntegerOperator(expr, "~") return newPrefixIntegerOperator(expr, "~")
} }
//----------- Comparison operators ---------------// //----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery // EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression { func EXISTS(subQuery Expression) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS") return newPrefixBoolOperator(subQuery, "EXISTS")
} }
@ -87,7 +95,7 @@ func CASE(expression ...Expression) CaseOperator {
caseExp.expression = expression[0] caseExp.expression = expression[0]
} }
caseExp.expressionInterfaceImpl.parent = caseExp caseExp.expressionInterfaceImpl.Parent = caseExp
return caseExp return caseExp
} }
@ -108,55 +116,33 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
return c return c
} }
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c == nil { out.WriteString("(CASE")
return errors.New("jet: Case Expression is nil. ")
}
out.writeString("(CASE")
if c.expression != nil { if c.expression != nil {
err := c.expression.serialize(statement, out) c.expression.serialize(statement, out)
if err != nil {
return err
}
} }
if len(c.when) == 0 || len(c.then) == 0 { if len(c.when) == 0 || len(c.then) == 0 {
return errors.New("jet: Invalid case Statement. There should be at least one when/then Expression pair. ") panic("jet: invalid case Statement. There should be at least one WHEN/THEN pair. ")
} }
if len(c.when) != len(c.then) { if len(c.when) != len(c.then) {
return errors.New("jet: When and then Expression count mismatch. ") panic("jet: WHEN and THEN expression count mismatch. ")
} }
for i, when := range c.when { for i, when := range c.when {
out.writeString("WHEN") out.WriteString("WHEN")
err := when.serialize(statement, out, noWrap) when.serialize(statement, out, noWrap)
if err != nil { out.WriteString("THEN")
return err c.then[i].serialize(statement, out, noWrap)
}
out.writeString("THEN")
err = c.then[i].serialize(statement, out, noWrap)
if err != nil {
return err
}
} }
if c.els != nil { if c.els != nil {
out.writeString("ELSE") out.WriteString("ELSE")
err := c.els.serialize(statement, out, noWrap) c.els.serialize(statement, out, noWrap)
if err != nil {
return err
}
} }
out.writeString("END)") out.WriteString("END)")
return nil
} }

View file

@ -5,10 +5,10 @@ import "testing"
func TestOperatorNOT(t *testing.T) { func TestOperatorNOT(t *testing.T) {
notExpression := NOT(Int(2).EQ(Int(1))) notExpression := NOT(Int(2).EQ(Int(1)))
assertClauseSerialize(t, NOT(table1ColBool), "NOT table1.col_bool") assertClauseSerialize(t, NOT(table1ColBool), "(NOT table1.col_bool)")
assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1)) assertClauseSerialize(t, notExpression, "(NOT ($1 = $2))", int64(2), int64(1))
assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1)) assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `(NOT ($1 = $2)) AS "alias_not_expression"`, int64(2), int64(1))
assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5)) assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `((NOT ($1 = $2)) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5))
} }
func TestCase1(t *testing.T) { func TestCase1(t *testing.T) {

View file

@ -0,0 +1,29 @@
package jet
// OrderByClause interface
type OrderByClause interface {
serializeForOrderBy(statement StatementType, out *SQLBuilder)
}
type orderByClauseImpl struct {
expression Expression
ascent bool
}
func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
if o.expression == nil {
panic("jet: nil expression in ORDER BY clause")
}
o.expression.serializeForOrderBy(statement, out)
if o.ascent {
out.WriteString("ASC")
} else {
out.WriteString("DESC")
}
}
func newOrderByClause(expression Expression, ascent bool) OrderByClause {
return &orderByClauseImpl{expression: expression, ascent: ascent}
}

View file

@ -0,0 +1,29 @@
package jet
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection interface {
serializeForProjection(statement StatementType, out *SQLBuilder)
fromImpl(subQuery SelectTable) Projection
}
// SerializeForProjection is helper function for serializing projection outside of jet package
func SerializeForProjection(projection Projection, statementType StatementType, out *SQLBuilder) {
projection.serializeForProjection(statementType, out)
}
// ProjectionList is a redefined type, so that ProjectionList can be used as a Projection.
type ProjectionList []Projection
func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, projection := range cl {
newProjectionList = append(newProjectionList, projection.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) {
SerializeProjectionList(statement, cl, out)
}

View file

@ -0,0 +1,47 @@
package jet
// RowLock is interface for SELECT statement row lock types
type RowLock interface {
Serializer
NOWAIT() RowLock
SKIP_LOCKED() RowLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
// NewRowLock creates new RowLock
func NewRowLock(name string) func() RowLock {
return func() RowLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) RowLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() RowLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() RowLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
}

View file

@ -0,0 +1,44 @@
package jet
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Serializer
Alias() string
AllColumns() ProjectionList
}
type selectTableImpl struct {
selectStmt StatementWithProjections
alias string
projections ProjectionList
}
// NewSelectTable func
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias}
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return &selectTable
}
func (s *selectTableImpl) Alias() string {
return s.alias
}
func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if s == nil {
panic("jet: expression table is nil. ")
}
s.selectStmt.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)
}

View file

@ -0,0 +1,43 @@
package jet
// SerializeOption type
type SerializeOption int
// Serialize options
const (
noWrap SerializeOption = iota
)
// StatementType is type of the SQL statement
type StatementType string
// Statement types
const (
SelectStatementType StatementType = "SELECT"
InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE"
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
)
// Serializer interface
type Serializer interface {
serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption)
}
// Serialize func
func Serialize(exp Serializer, statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
exp.serialize(statementType, out, options...)
}
func contains(options []SerializeOption, option SerializeOption) bool {
for _, opt := range options {
if opt == option {
return true
}
}
return false
}

View file

@ -21,8 +21,10 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(uint(32)), "32") assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32") assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64") assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'") assert.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString("It's text"), "'It''s text'")
assert.Equal(t, argToString([]byte("john")), "'john'") assert.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")

196
internal/jet/sql_builder.go Normal file
View file

@ -0,0 +1,196 @@
package jet
import (
"bytes"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
"unicode"
)
// SQLBuilder generates output SQL
type SQLBuilder struct {
Dialect Dialect
Buff bytes.Buffer
Args []interface{}
lastChar byte
ident int
debug bool
}
const defaultIdent = 5
// IncreaseIdent adds ident or defaultIdent number of spaces to each new line
func (s *SQLBuilder) IncreaseIdent(ident ...int) {
if len(ident) > 0 {
s.ident += ident[0]
} else {
s.ident += defaultIdent
}
}
// DecreaseIdent removes ident or defaultIdent number of spaces for each new line
func (s *SQLBuilder) DecreaseIdent(ident ...int) {
toDecrease := defaultIdent
if len(ident) > 0 {
toDecrease = ident[0]
}
if s.ident < toDecrease {
s.ident = 0
}
s.ident -= toDecrease
}
// WriteProjections func
func (s *SQLBuilder) WriteProjections(statement StatementType, projections []Projection) {
s.IncreaseIdent()
SerializeProjectionList(statement, projections, s)
s.DecreaseIdent()
}
// NewLine adds new line to output SQL
func (s *SQLBuilder) NewLine() {
s.write([]byte{'\n'})
s.write(bytes.Repeat([]byte{' '}, s.ident))
}
func (s *SQLBuilder) write(data []byte) {
if len(data) == 0 {
return
}
if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 {
s.Buff.WriteByte(' ')
}
s.Buff.Write(data)
s.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
// WriteAlias is used to add alias to output SQL
func (s *SQLBuilder) WriteAlias(str string) {
aliasQuoteChar := string(s.Dialect.AliasQuoteChar())
s.WriteString(aliasQuoteChar + str + aliasQuoteChar)
}
// WriteString writes sting to output SQL
func (s *SQLBuilder) WriteString(str string) {
s.write([]byte(str))
}
// WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar)
} else {
s.WriteString(name)
}
}
// WriteByte writes byte to output SQL
func (s *SQLBuilder) WriteByte(b byte) {
s.write([]byte{b})
}
func (s *SQLBuilder) finalize() (string, []interface{}) {
return s.Buff.String() + ";\n", s.Args
}
func (s *SQLBuilder) insertConstantArgument(arg interface{}) {
s.WriteString(argToString(arg))
}
func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) {
if s.debug {
s.insertConstantArgument(arg)
return
}
s.Args = append(s.Args, arg)
argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args))
s.WriteString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(bindVal, 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
}
func shouldQuoteIdentifier(identifier string) bool {
for _, c := range identifier {
if unicode.IsNumber(c) || c == '_' {
continue
}
if c > unicode.MaxASCII || !unicode.IsLetter(c) || unicode.IsUpper(c) {
return true
}
}
return false
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

172
internal/jet/statement.go Normal file
View file

@ -0,0 +1,172 @@
package jet
import (
"context"
"database/sql"
"github.com/go-jet/jet/execution"
)
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement interface {
// Sql returns parametrized sql query with list of arguments.
Sql() (query string, args []interface{})
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes.
DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination.
// Destination can be arbitrary structure
Query(db execution.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be of arbitrary structure
QueryContext(context context.Context, db execution.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows.
Exec(db execution.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows.
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
}
// SerializerStatement interface
type SerializerStatement interface {
Serializer
Statement
}
// StatementWithProjections interface
type StatementWithProjections interface {
Statement
HasProjections
Serializer
}
// HasProjections interface
type HasProjections interface {
projections() ProjectionList
}
// serializerStatementInterfaceImpl struct
type serializerStatementInterfaceImpl struct {
dialect Dialect
statementType StatementType
parent SerializerStatement
}
func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface{}) {
queryData := &SQLBuilder{Dialect: s.dialect}
s.parent.serialize(s.statementType, queryData, noWrap)
query, args = queryData.finalize()
return
}
func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SQLBuilder{Dialect: s.dialect, debug: true}
s.parent.serialize(s.statementType, sqlBuilder, noWrap)
query, _ = sqlBuilder.finalize()
return
}
func (s *serializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error {
query, args := s.Sql()
return execution.Query(context.Background(), db, query, args, destination)
}
func (s *serializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
query, args := s.Sql()
return execution.Query(context, db, query, args, destination)
}
func (s *serializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.Exec(query, args...)
}
func (s *serializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.ExecContext(context, query, args...)
}
// ExpressionStatement interfacess
type ExpressionStatement interface {
Expression
Statement
HasProjections
}
// NewExpressionStatementImpl creates new expression statement
func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement {
return &expressionStatementImpl{
expressionInterfaceImpl{Parent: parent},
statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
parent: parent,
dialect: Dialect,
statementType: statementType,
},
Clauses: clauses,
},
}
}
type expressionStatementImpl struct {
expressionInterfaceImpl
statementImpl
}
func (s *expressionStatementImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
s.serialize(statement, out)
}
// NewStatementImpl creates new statementImpl
func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) SerializerStatement {
return &statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
parent: parent,
dialect: Dialect,
statementType: statementType,
},
Clauses: clauses,
}
}
type statementImpl struct {
serializerStatementInterfaceImpl
Clauses []Clause
}
func (s *statementImpl) projections() ProjectionList {
for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections()
}
}
return nil
}
func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
out.WriteString("(")
out.IncreaseIdent()
}
for _, clause := range s.Clauses {
clause.Serialize(statement, out)
}
if !contains(options, noWrap) {
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")
}
}

View file

@ -18,8 +18,9 @@ type StringExpression interface {
LIKE(pattern StringExpression) BoolExpression LIKE(pattern StringExpression) BoolExpression
NOT_LIKE(pattern StringExpression) BoolExpression NOT_LIKE(pattern StringExpression) BoolExpression
SIMILAR_TO(pattern StringExpression) BoolExpression
NOT_SIMILAR_TO(pattern StringExpression) BoolExpression REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression
NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression
} }
type stringInterfaceImpl struct { type stringInterfaceImpl struct {
@ -59,7 +60,7 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression {
} }
func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression {
return newBinaryStringExpression(s.parent, rhs, "||") return newBinaryStringExpression(s.parent, rhs, StringConcatOperator)
} }
func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression {
@ -70,15 +71,16 @@ func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression
return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE") return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE")
} }
func (s *stringInterfaceImpl) SIMILAR_TO(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "SIMILAR TO") return newBinaryBoolOperator(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
} }
func (s *stringInterfaceImpl) NOT_SIMILAR_TO(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "NOT SIMILAR TO") return newBinaryBoolOperator(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
} }
//---------------------------------------------------// //---------------------------------------------------//
type binaryStringExpression struct { type binaryStringExpression struct {
expressionInterfaceImpl expressionInterfaceImpl
stringInterfaceImpl stringInterfaceImpl
@ -90,7 +92,7 @@ func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpre
boolExpression := binaryStringExpression{} boolExpression := binaryStringExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression boolExpression.expressionInterfaceImpl.Parent = &boolExpression
boolExpression.stringInterfaceImpl.parent = &boolExpression boolExpression.stringInterfaceImpl.parent = &boolExpression
return &boolExpression return &boolExpression

View file

@ -66,14 +66,14 @@ func TestStringNOT_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN") assertClauseSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN")
} }
func TestStringSIMILAR_TO(t *testing.T) { func TestStringREGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.SIMILAR_TO(table2ColStr), "(table3.col2 SIMILAR TO table2.col_str)") assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.SIMILAR_TO(String("JOHN")), "(table3.col2 SIMILAR TO $1)", "JOHN") assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP $1)", "JOHN")
} }
func TestStringNOT_SIMILAR_TO(t *testing.T) { func TestStringNOT_REGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(table2ColStr), "(table3.col2 NOT SIMILAR TO table2.col_str)") assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(String("JOHN")), "(table3.col2 NOT SIMILAR TO $1)", "JOHN") assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP $1)", "JOHN")
} }
func TestStringExp(t *testing.T) { func TestStringExp(t *testing.T) {

188
internal/jet/table.go Normal file
View file

@ -0,0 +1,188 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
)
// SerializerTable interface
type SerializerTable interface {
Serializer
Table
}
// Table interface
type Table interface {
columns() []Column
SchemaName() string
TableName() string
AS(alias string)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable {
t := tableImpl{
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.setTableName(name)
}
return &t
}
type tableImpl struct {
schemaName string
name string
alias string
columnList []ColumnExpression
}
func (t *tableImpl) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
c.setTableName(alias)
}
}
func (t *tableImpl) SchemaName() string {
return t.schemaName
}
func (t *tableImpl) TableName() string {
return t.name
}
func (t *tableImpl) columns() []Column {
ret := []Column{}
for _, col := range t.columnList {
ret = append(ret, col)
}
return ret
}
func (t *tableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if t == nil {
panic("jet: tableImpl is nil")
}
out.WriteIdentifier(t.schemaName)
out.WriteString(".")
out.WriteIdentifier(t.name)
if len(t.alias) > 0 {
out.WriteString("AS")
out.WriteIdentifier(t.alias)
}
}
// JoinType is type of table join
type JoinType int
// Table join types
const (
InnerJoin JoinType = iota
LeftJoin
RightJoin
FullJoin
CrossJoin
)
// Join expressions are pseudo readable tables.
type joinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
// JoinTable interface
type JoinTable SerializerTable
// NewJoinTable creates new join table
func NewJoinTable(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTable {
joinTable := joinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
return &joinTable
}
func (t *joinTableImpl) SchemaName() string {
if table, ok := t.lhs.(Table); ok {
return table.SchemaName()
}
return ""
}
func (t *joinTableImpl) TableName() string {
return ""
}
func (t *joinTableImpl) AS(alias string) {
}
func (t *joinTableImpl) columns() []Column {
var ret []Column
if lhsTable, ok := t.lhs.(Table); ok {
ret = append(ret, lhsTable.columns()...)
}
if rhsTable, ok := t.rhs.(Table); ok {
ret = append(ret, rhsTable.columns()...)
}
return ret
}
func (t *joinTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if t == nil {
panic("jet: Join table is nil. ")
}
if utils.IsNil(t.lhs) {
panic("jet: left hand side of join operation is nil table")
}
t.lhs.serialize(statement, out)
out.NewLine()
switch t.joinType {
case InnerJoin:
out.WriteString("INNER JOIN")
case LeftJoin:
out.WriteString("LEFT JOIN")
case RightJoin:
out.WriteString("RIGHT JOIN")
case FullJoin:
out.WriteString("FULL JOIN")
case CrossJoin:
out.WriteString("CROSS JOIN")
}
if utils.IsNil(t.rhs) {
panic("jet: right hand side of join operation is nil table")
}
t.rhs.serialize(statement, out)
if t.onCondition == nil && t.joinType != CrossJoin {
panic("jet: join condition is nil")
}
if t.onCondition != nil {
out.WriteString("ON")
t.onCondition.serialize(statement, out)
}
}

85
internal/jet/testutils.go Normal file
View file

@ -0,0 +1,85 @@
package jet
import (
"gotest.tools/assert"
"strconv"
"testing"
)
var defaultDialect = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
var table2ColInt = IntegerColumn("col_int")
var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool")
var table2ColTime = TimeColumn("col_time")
var table2ColTimez = TimezColumn("col_timez")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2")
var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol)
func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect}
clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
out := SQLBuilder{Dialect: defaultDialect}
clause.serialize(SelectStatementType, &out)
}
func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect, debug: true}
clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect}
projection.serializeForProjection(SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}

View file

@ -2,52 +2,53 @@ package jet
import ( import (
"testing" "testing"
"time"
) )
var timeVar = Time(10, 20, 0, 0) var timeVar = Time(10, 20, 0, 0)
func TestTimeExpressionEQ(t *testing.T) { func TestTimeExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)") assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)")
assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1)", "10:20:00")
} }
func TestTimeExpressionNOT_EQ(t *testing.T) { func TestTimeExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)") assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)")
assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1)", "10:20:00")
} }
func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)") assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1)", "10:20:00")
} }
func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)") assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1)", "10:20:00")
} }
func TestTimeExpressionLT(t *testing.T) { func TestTimeExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)") assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1)", "10:20:00")
} }
func TestTimeExpressionLT_EQ(t *testing.T) { func TestTimeExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)") assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1)", "10:20:00")
} }
func TestTimeExpressionGT(t *testing.T) { func TestTimeExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)") assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1)", "10:20:00")
} }
func TestTimeExpressionGT_EQ(t *testing.T) { func TestTimeExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)") assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1)", "10:20:00")
} }
func TestTimeExp(t *testing.T) { func TestTimeExp(t *testing.T) {
assertClauseSerialize(t, TimeExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimeExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1)), assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)),
"(table1.col_float < $1::time without time zone)", string("01:01:01.001")) "(table1.col_float < $1)", string("01:01:01.001"))
} }

View file

@ -1,52 +1,55 @@
package jet package jet
import "testing" import (
"testing"
"time"
)
var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 0) var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 3*time.Millisecond)
func TestTimestampExpressionEQ(t *testing.T) { func TestTimestampExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp), assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp),
"(table1.col_timestamp = $1::timestamp without time zone)", "2000-01-31 10:20:00.000") "(table1.col_timestamp = $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionNOT_EQ(t *testing.T) { func TestTimestampExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionLT(t *testing.T) { func TestTimestampExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionLT_EQ(t *testing.T) { func TestTimestampExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionGT(t *testing.T) { func TestTimestampExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionGT_EQ(t *testing.T) { func TestTimestampExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExp(t *testing.T) { func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1::timestamp without time zone)", "2000-01-31 10:20:00.000") "(table1.col_float < $1)", "2000-01-31 10:20:00.003")
} }

View file

@ -51,6 +51,15 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression
return gtEq(t.parent, rhs) return gtEq(t.parent, rhs)
} }
//---------------------------------------------------//
type prefixTimestampzOperator struct {
expressionInterfaceImpl
timestampzInterfaceImpl
prefixOpExpression
}
//------------------------------------------------- //-------------------------------------------------
type timestampzExpressionWrapper struct { type timestampzExpressionWrapper struct {

View file

@ -1,52 +1,55 @@
package jet package jet
import "testing" import (
"testing"
"time"
)
var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2) var timestampz = Timestampz(2000, 1, 31, 10, 20, 5, 23*time.Microsecond, "+200")
func TestTimestampzExpressionEQ(t *testing.T) { func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz), assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") "(table1.col_timestampz = $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionNOT_EQ(t *testing.T) { func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionLT(t *testing.T) { func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionLT_EQ(t *testing.T) { func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionGT(t *testing.T) { func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionGT_EQ(t *testing.T) { func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExp(t *testing.T) { func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200")
} }

View file

@ -1,6 +1,6 @@
package jet package jet
// TimezExpression interface 'time with time zone' // TimezExpression interface for 'time with time zone' types
type TimezExpression interface { type TimezExpression interface {
Expression Expression

View file

@ -2,50 +2,50 @@ package jet
import "testing" import "testing"
var timezVar = Timez(10, 20, 0, 0, 4) var timezVar = Timez(10, 20, 0, 0, "+4:00")
func TestTimezExpressionEQ(t *testing.T) { func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)") assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionNOT_EQ(t *testing.T) { func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)") assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)") assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)") assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionLT(t *testing.T) { func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)") assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionLT_EQ(t *testing.T) { func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)") assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionGT(t *testing.T) { func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)") assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionGT_EQ(t *testing.T) { func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)") assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1)", "10:20:00 +4:00")
} }
func TestTimezExp(t *testing.T) { func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)), assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04")) "(table1.col_float < $1)", string("01:01:01.000000001 +4:00"))
} }

178
internal/jet/utils.go Normal file
View file

@ -0,0 +1,178 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
"reflect"
)
// SerializeClauseList func
func SerializeClauseList(statement StatementType, clauses []Serializer, out *SQLBuilder) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
}
if c == nil {
panic("jet: nil clause")
}
c.serialize(statement, out)
}
}
func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SQLBuilder) {
for i, value := range expressions {
if i > 0 {
out.WriteString(separator)
}
value.serialize(statement, out)
}
}
// SerializeProjectionList func
func SerializeProjectionList(statement StatementType, projections []Projection, out *SQLBuilder) {
for i, col := range projections {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
if col == nil {
panic("jet: Projection is nil")
}
col.serializeForProjection(statement, out)
}
}
// SerializeColumnNames func
func SerializeColumnNames(columns []Column, out *SQLBuilder) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
out.WriteString(col.Name())
}
}
// ColumnListToProjectionList func
func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection
for _, column := range columns {
ret = append(ret, column)
}
return ret
}
func valueToClause(value interface{}) Serializer {
if clause, ok := value.(Serializer); ok {
return clause
}
return literal(value)
}
// UnwindRowFromModel func
func UnwindRowFromModel(columns []Column, data interface{}) []Serializer {
structValue := reflect.Indirect(reflect.ValueOf(data))
row := []Serializer{}
utils.ValueMustBe(structValue, reflect.Struct, "jet: data has to be a struct")
for _, column := range columns {
columnName := column.Name()
structFieldName := utils.ToGoIdentifier(columnName)
structField := structValue.FieldByName(structFieldName)
if !structField.IsValid() {
panic("missing struct field for column : " + columnName)
}
var field interface{}
if structField.Kind() == reflect.Ptr && structField.IsNil() {
field = nil
} else {
field = reflect.Indirect(structField).Interface()
}
row = append(row, literal(field))
}
return row
}
// UnwindRowsFromModels func
func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer {
sliceValue := reflect.Indirect(reflect.ValueOf(data))
utils.ValueMustBe(sliceValue, reflect.Slice, "jet: data has to be a slice.")
rows := [][]Serializer{}
for i := 0; i < sliceValue.Len(); i++ {
structValue := sliceValue.Index(i)
rows = append(rows, UnwindRowFromModel(columns, structValue.Interface()))
}
return rows
}
// UnwindRowFromValues func
func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer {
row := []Serializer{}
allValues := append([]interface{}{value}, values...)
for _, val := range allValues {
row = append(row, valueToClause(val))
}
return row
}
// UnwindColumns func
func UnwindColumns(column1 Column, columns ...Column) []Column {
columnList := []Column{}
if val, ok := column1.(IColumnList); ok {
for _, col := range val.columns() {
columnList = append(columnList, col)
}
columnList = append(columnList, columns...)
} else {
columnList = append(columnList, column1)
columnList = append(columnList, columns...)
}
return columnList
}
// UnwidColumnList func
func UnwidColumnList(columns []Column) []Column {
ret := []Column{}
for _, col := range columns {
if columnList, ok := col.(IColumnList); ok {
for _, c := range columnList.columns() {
ret = append(ret, c)
}
} else {
ret = append(ret, col)
}
}
return ret
}

Some files were not shown because too many files have changed in this diff Show more