Merge pull request #7 from go-jet/develop

Merge develop to master
This commit is contained in:
go-jet 2019-07-23 13:42:47 +02:00 committed by GitHub
commit 05311b8129
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
84 changed files with 9819 additions and 1140 deletions

View file

@ -34,10 +34,7 @@ jobs:
go get github.com/davecgh/go-spew/spew go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report go get github.com/jstemmer/go-junit-report
- run: mkdir -p $TEST_RESULTS/unit-tests go install github.com/go-jet/jet/cmd/jet
- run: mkdir -p $TEST_RESULTS/integration-tests
- run: go test -v 2>&1 | go-junit-report > $TEST_RESULTS/unit-tests/results.xml
- run: - run:
name: Waiting for Postgres to be ready name: Waiting for Postgres to be ready
@ -51,13 +48,19 @@ jobs:
echo Failed waiting for Postgres && exit 1 echo Failed waiting for Postgres && exit 1
- run: - run:
name: Run integration tests name: Init Postgres database
command: | command: |
cd tests cd tests
go run ./init/init.go go run ./init/init.go
go test -v 2>&1 | go-junit-report > $TEST_RESULTS/integration-tests/results.xml
cd .. cd ..
- run: mkdir -p $TEST_RESULTS
- run: go test -v . ./tests -coverpkg=github.com/go-jet/jet,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run:
name: Upload code coverage
command: bash <(curl -s https://codecov.io/bash)
- store_artifacts: # Upload test summary for display in Artifacts: https://circleci.com/docs/2.0/artifacts/ - store_artifacts: # Upload test summary for display in Artifacts: https://circleci.com/docs/2.0/artifacts/
path: /tmp/test-results path: /tmp/test-results
destination: raw-test-output destination: raw-test-output

View file

@ -1,10 +1,13 @@
# Jet # Jet
[![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet)
[![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet)
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop) [![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily
convert database query result to desired arbitrary structure. convert database query result to desired arbitrary structure.
_Support for additional databases will be added in future jet releases._ _*Support for additional databases will be added in future jet releases._
## Contents ## Contents
@ -22,7 +25,7 @@ _Support for additional databases will be added in future jet releases._
- [License](#license) - [License](#license)
## Features ## Features
1) Type-safe SQL Builder 1) Auto-generated type-safe SQL Builder
- Types - boolean, integers(smallint, integer, bigint), floats(real, numeric, decimal, double precision), - Types - boolean, integers(smallint, integer, bigint), floats(real, numeric, decimal, double precision),
strings(text, character, character varying), date, time(z), timestamp(z) and enums. strings(text, character, character varying), date, time(z), timestamp(z) and enums.
- Statements: - Statements:
@ -31,10 +34,9 @@ _Support for additional databases will be added in future jet releases._
* UPDATE (SET, WHERE, RETURNING), * UPDATE (SET, WHERE, RETURNING),
* DELETE (WHERE, RETURNING), * DELETE (WHERE, RETURNING),
* LOCK (IN, NOWAIT) * LOCK (IN, NOWAIT)
2) Auto-generated Data Model types - Go struct mapped to database type (table or enum), used to store 2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store
result of database queries. result of database queries. Can be combined to create desired query result destination.
3) Query execution with mapping to arbitrary destination structure - destination structure can be 3) Query execution with result mapping to arbitrary destination structure.
created by combining auto-generated data model types.
## Getting Started ## Getting Started
@ -63,15 +65,16 @@ Make sure GOPATH bin folder is added to the PATH environment variable.
For this quick start example we will use sample _dvd rental_ database. Full database dump can be found in [./tests/init/data/dvds.sql](./tests/init/data/dvds.sql). For this quick start example we will use sample _dvd rental_ database. Full database dump can be found in [./tests/init/data/dvds.sql](./tests/init/data/dvds.sql).
Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png). Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png).
#### Generate sql builder and model files #### Generate SQL Builder and Model files
To generate jet SQL Builder and Data Model files from postgres database we need to call `jet` generator with postgres To generate jet SQL Builder and Data Model files from postgres database we need to call `jet` generator with postgres
connection parameters and root destination folder path for generated files.\ connection parameters and root destination folder path for generated files.\
Assuming we are running local postgres database, with user `jet`, database `jetdb` and schema `dvds` we will use this command: Assuming we are running local postgres database, with user `jetuser`, user password `jetpass`, database `jetdb` and
schema `dvds` we will use this command:
```sh ```sh
jet -host=localhost -port=5432 -user=jet -password=jet -dbname=jetdb -schema=dvds -path=./gen jet -host=localhost -port=5432 -user=jetuser -password=jetpass -dbname=jetdb -schema=dvds -path=./gen
``` ```
```sh ```sh
Connecting to postgres database: host=localhost port=5432 user=jet password=jet dbname=jetdb sslmode=disable Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable
Retrieving schema information... Retrieving schema information...
FOUND 15 table(s), 1 enum(s) FOUND 15 table(s), 1 enum(s)
Destination directory: ./gen/jetdb/dvds Destination directory: ./gen/jetdb/dvds
@ -82,11 +85,12 @@ Generating enum sql builder files...
Generating enum model files... Generating enum model files...
Done Done
``` ```
_*User has to have a permission to read information schema tables_
As command output suggest, Jet will: As command output suggest, Jet will:
- connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema - connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema
- delete everything in schema destination folder - `./gen/jetdb/dvds`, - delete everything in schema destination folder - `./gen/jetdb/dvds`,
- and finally generate sql builder and model files for each schema tables and enums. - and finally generate SQL Builder and Model files for each schema table and enum.
Generated files folder structure will look like this: Generated files folder structure will look like this:
@ -101,13 +105,13 @@ Generated files folder structure will look like this:
| |-- address.go | |-- address.go
| |-- category.go | |-- category.go
| ... | ...
| |-- model # Plain Old Data for every table and enum | |-- model # model files for each table and enum
| | |-- actor.go | | |-- actor.go
| | |-- address.go | | |-- address.go
| | |-- mpaa_rating.go | | |-- mpaa_rating.go
| | ... | | ...
``` ```
Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types are combined to store Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store
results of the SQL queries. results of the SQL queries.
#### Lets write some SQL queries in Go #### Lets write some SQL queries in Go
@ -147,12 +151,11 @@ stmt := SELECT(
Film.FilmID.ASC(), Film.FilmID.ASC(),
) )
``` ```
With package(dot) import above statements looks almost the same as native SQL. With package(dot) import above statements looks almost the same as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns
and can be compared only with integer columns and expressions. and can be compared only with integer columns and expressions.
__How to get parametrized SQL query?__ __How to get parametrized SQL query from statement?__
```go ```go
query, args, err := stmt.Sql() query, args, err := stmt.Sql()
``` ```
@ -160,7 +163,7 @@ query - parametrized query\
args - parameters for the query args - parameters for the query
<details> <details>
<summary>Click to see `query` and `arg`</summary> <summary>Click to see `query` and `args`</summary>
```sql ```sql
SELECT actor.actor_id AS "actor.actor_id", SELECT actor.actor_id AS "actor.actor_id",
@ -202,11 +205,12 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
</details> </details>
__How to get debug SQL that can be copy pasted to sql editor and executed?__ __How to get debug SQL from statement?__
```go ```go
debugSql, err := stmt.DebugSql() debugSql, err := stmt.DebugSql()
``` ```
debugSql - parametrized query where every parameter is replaced with appropriate string argument representation debugSql - query string that can be copy pasted to sql editor and executed. It's not intended to be used in production.
<details> <details>
<summary>Click to see debug sql</summary> <summary>Click to see debug sql</summary>
@ -252,20 +256,24 @@ Well formed SQL is just a first half the job. Lets see how can we make some sens
above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest. above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest.
First we have to create desired structure to store query result set. First we have to create desired structure to store query result set.
This is done be combining autogenerated model types or it can be done manually(see wiki for more information). This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information).
Let's say this is our desired structure: Let's say this is our desired structure, created by combining auto-generated model types:
```go ```go
var dest []struct { var dest []struct {
model.Actor model.Actor
Films []struct { Films []struct {
model.Film model.Film
Language model.Language Language model.Language
Categories []model.Category Categories []model.Category
} }
} }
``` ```
_There is no limitation for how big or nested destination structure can be._ Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language
`Langauge` field is just a single model struct.
_*There is no limitation of how big or nested destination structure can be._
Now lets execute a above statement on open database connection db and store result into `dest`. Now lets execute a above statement on open database connection db and store result into `dest`.
@ -494,12 +502,12 @@ ORM sometimes can access the database once for every object needed. Now lets say
different objects required from the database. This handler will last 3 seconds !!!. different objects required from the database. This handler will last 3 seconds !!!.
With Jet, handler time lost on latency between server and database is constant. Because we can write complex query and With Jet, handler time lost on latency between server and database is constant. Because we can write complex query and
return result in one database call. Handler execution will be proportional to the number of rows returned from database. return result in one database call. Handler execution will be only proportional to the number of rows returned from database.
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate). ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call. With Jet you can even join the whole database and store the whole structured result in in one query call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/chinook_db_test.go#L40). This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/chinook_db_test.go#L40).
The whole test database is joined and query result is stored in a structured variable in less than 1s. The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found ##### How quickly bugs are found

View file

@ -12,7 +12,7 @@ func newAlias(expression Expression, aliasName string) projection {
} }
} }
func (a *alias) from(subQuery ExpressionTable) projection { func (a *alias) from(subQuery SelectTable) projection {
column := newColumn(a.alias, "", nil) column := newColumn(a.alias, "", nil)
column.parent = &column column.parent = &column
column.subQuery = subQuery column.subQuery = subQuery

View file

@ -1,5 +1,6 @@
package jet package jet
//BoolExpression interface
type BoolExpression interface { type BoolExpression interface {
Expression Expression
@ -150,6 +151,9 @@ func newBoolExpressionWrap(expression Expression) BoolExpression {
return &boolExpressionWrap return &boolExpressionWrap
} }
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
func BoolExp(expression Expression) BoolExpression { func BoolExp(expression Expression) BoolExpression {
return newBoolExpressionWrap(expression) return newBoolExpressionWrap(expression)
} }

View file

@ -15,6 +15,16 @@ func TestBoolExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColBool.NOT_EQ(Bool(true)), "(table1.col_bool != $1)", true) assertClauseSerialize(t, table1ColBool.NOT_EQ(Bool(true)), "(table1.col_bool != $1)", true)
} }
func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS DISTINCT FROM table2.col_bool)")
assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS DISTINCT FROM $1)", false)
}
func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS NOT DISTINCT FROM table2.col_bool)")
assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS NOT DISTINCT FROM $1)", false)
}
func TestBoolExpressionIS_TRUE(t *testing.T) { func TestBoolExpressionIS_TRUE(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_TRUE(), "table1.col_bool IS TRUE") assertClauseSerialize(t, table1ColBool.IS_TRUE(), "table1.col_bool IS TRUE")
assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(), assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(),

View file

@ -36,6 +36,8 @@ type castImpl struct {
castType string castType string
} }
// CAST wraps expression for casting.
// For instance: CAST(table.column).AS_BOOL()
func CAST(expression Expression) cast { func CAST(expression Expression) cast {
return &castImpl{ return &castImpl{
Expression: expression, Expression: expression,

View file

@ -40,12 +40,12 @@ type sqlBuilder struct {
type statementType string type statementType string
const ( const (
select_statement statementType = "SELECT" selectStatement statementType = "SELECT"
insert_statement statementType = "INSERT" insertStatement statementType = "INSERT"
update_statement statementType = "UPDATE" updateStatement statementType = "UPDATE"
delete_statement statementType = "DELETE" deleteStatement statementType = "DELETE"
set_statement statementType = "SET" setStatement statementType = "SET"
lock_statement statementType = "LOCK" lockStatement statementType = "LOCK"
) )
const defaultIdent = 5 const defaultIdent = 5
@ -102,7 +102,7 @@ func (q *sqlBuilder) writeGroupBy(statement statementType, groupBy []groupByClau
return err return err
} }
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []OrderByClause) error { func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.newLine() q.newLine()
q.writeString("ORDER BY") q.writeString("ORDER BY")
@ -189,23 +189,18 @@ func (q *sqlBuilder) finalize() (string, []interface{}) {
} }
func (q *sqlBuilder) insertConstantArgument(arg interface{}) { func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
q.writeString(ArgToString(arg)) q.writeString(argToString(arg))
} }
func (q *sqlBuilder) insertPreparedArgument(arg interface{}) { func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg) q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args)) argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.writeString(argPlaceholder) q.writeString(argPlaceholder)
} }
func (q *sqlBuilder) reset() { func argToString(value interface{}) string {
q.buff.Reset() if utils.IsNil(value) {
q.args = []interface{}{}
}
func ArgToString(value interface{}) string {
if isNil(value) {
return "NULL" return "NULL"
} }
@ -213,9 +208,8 @@ func ArgToString(value interface{}) string {
case bool: case bool:
if bindVal { if bindVal {
return "TRUE" return "TRUE"
} else {
return "FALSE"
} }
return "FALSE"
case int8: case int8:
return strconv.FormatInt(int64(bindVal), 10) return strconv.FormatInt(int64(bindVal), 10)
case int: case int:
@ -252,7 +246,7 @@ func ArgToString(value interface{}) string {
case time.Time: case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal))) return stringQuote(string(utils.FormatTimestamp(bindVal)))
default: default:
return "[Unknown type]" return "[Unsupported type]"
} }
} }

33
clause_test.go Normal file
View file

@ -0,0 +1,33 @@
package jet
import (
"github.com/google/uuid"
"gotest.tools/assert"
"testing"
"time"
)
func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(true), "TRUE")
assert.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int8(-8)), "-8")
assert.Equal(t, argToString(int16(-16)), "-16")
assert.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(int32(-32)), "-32")
assert.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint8(8)), "8")
assert.Equal(t, argToString(uint16(16)), "16")
assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NilError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
assert.Equal(t, argToString(map[string]bool{}), "[Unsupported type]")
}

View file

@ -4,8 +4,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/generator/postgres"
_ "github.com/lib/pq"
"os" "os"
"strconv"
) )
var ( var (
@ -70,7 +70,7 @@ Usage of jet:
genData := postgres.DBConnection{ genData := postgres.DBConnection{
Host: host, Host: host,
Port: strconv.Itoa(port), Port: port,
User: user, User: user,
Password: password, Password: password,
SslMode: sslmode, SslMode: sslmode,

View file

@ -7,10 +7,11 @@ type column interface {
TableName() string TableName() string
setTableName(table string) setTableName(table string)
setSubQuery(subQuery ExpressionTable) setSubQuery(subQuery SelectTable)
defaultAlias() string defaultAlias() string
} }
// Column is common column interface for all types of columns.
type Column interface { type Column interface {
Expression Expression
column column
@ -23,7 +24,7 @@ type columnImpl struct {
name string name string
tableName string tableName string
subQuery ExpressionTable subQuery SelectTable
} }
func newColumn(name string, tableName string, parent Column) columnImpl { func newColumn(name string, tableName string, parent Column) columnImpl {
@ -49,7 +50,7 @@ func (c *columnImpl) setTableName(table string) {
c.tableName = table c.tableName = table
} }
func (c *columnImpl) setSubQuery(subQuery ExpressionTable) { func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery c.subQuery = subQuery
} }
@ -62,7 +63,7 @@ func (c *columnImpl) defaultAlias() string {
} }
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error { func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if statement == set_statement { if statement == setStatement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.writeString(`"` + c.defaultAlias() + `"`) //always quote out.writeString(`"` + c.defaultAlias() + `"`) //always quote
@ -104,13 +105,13 @@ func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options
//------------------------------------------------------// //------------------------------------------------------//
// Redefined type to support list of columns as projection // ColumnList is redefined type to support list of columns as single projection
type ColumnList []Column type ColumnList []Column
// projection interface implementation // projection interface implementation
func (cl ColumnList) isProjectionType() {} func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) from(subQuery ExpressionTable) projection { func (cl ColumnList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{} newProjectionList := ProjectionList{}
for _, column := range cl { for _, column := range cl {
@ -134,8 +135,11 @@ func (cl ColumnList) serializeForProjection(statement statementType, out *sqlBui
// dummy column interface implementation // dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" } func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" } func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {} func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery ExpressionTable) {} func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" } func (cl ColumnList) defaultAlias() string { return "" }

View file

@ -1,10 +1,11 @@
package jet package jet
// ColumnBool is interface for SQL boolean columns.
type ColumnBool interface { type ColumnBool interface {
BoolExpression BoolExpression
column column
From(subQuery ExpressionTable) ColumnBool From(subQuery SelectTable) ColumnBool
} }
type boolColumnImpl struct { type boolColumnImpl struct {
@ -13,7 +14,7 @@ type boolColumnImpl struct {
columnImpl columnImpl
} }
func (i *boolColumnImpl) from(subQuery ExpressionTable) projection { func (i *boolColumnImpl) from(subQuery SelectTable) projection {
newBoolColumn := BoolColumn(i.name) newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(i.tableName) newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery) newBoolColumn.setSubQuery(subQuery)
@ -21,12 +22,13 @@ func (i *boolColumnImpl) from(subQuery ExpressionTable) projection {
return newBoolColumn return newBoolColumn
} }
func (i *boolColumnImpl) From(subQuery ExpressionTable) ColumnBool { func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := i.from(subQuery).(ColumnBool) newBoolColumn := i.from(subQuery).(ColumnBool)
return newBoolColumn return newBoolColumn
} }
// BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool { func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{} boolColumn := &boolColumnImpl{}
boolColumn.columnImpl = newColumn(name, "", boolColumn) boolColumn.columnImpl = newColumn(name, "", boolColumn)
@ -37,11 +39,12 @@ func BoolColumn(name string) ColumnBool {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface { type ColumnFloat interface {
FloatExpression FloatExpression
column column
From(subQuery ExpressionTable) ColumnFloat From(subQuery SelectTable) ColumnFloat
} }
type floatColumnImpl struct { type floatColumnImpl struct {
@ -49,7 +52,7 @@ type floatColumnImpl struct {
columnImpl columnImpl
} }
func (i *floatColumnImpl) from(subQuery ExpressionTable) projection { func (i *floatColumnImpl) from(subQuery SelectTable) projection {
newFloatColumn := FloatColumn(i.name) newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName) newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery) newFloatColumn.setSubQuery(subQuery)
@ -57,12 +60,13 @@ func (i *floatColumnImpl) from(subQuery ExpressionTable) projection {
return newFloatColumn return newFloatColumn
} }
func (i *floatColumnImpl) From(subQuery ExpressionTable) ColumnFloat { func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
newFloatColumn := i.from(subQuery).(ColumnFloat) newFloatColumn := i.from(subQuery).(ColumnFloat)
return newFloatColumn return newFloatColumn
} }
// FloatColumn creates named float column.
func FloatColumn(name string) ColumnFloat { func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{} floatColumn := &floatColumnImpl{}
floatColumn.floatInterfaceImpl.parent = floatColumn floatColumn.floatInterfaceImpl.parent = floatColumn
@ -73,11 +77,12 @@ func FloatColumn(name string) ColumnFloat {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface { type ColumnInteger interface {
IntegerExpression IntegerExpression
column column
From(subQuery ExpressionTable) ColumnInteger From(subQuery SelectTable) ColumnInteger
} }
type integerColumnImpl struct { type integerColumnImpl struct {
@ -86,7 +91,7 @@ type integerColumnImpl struct {
columnImpl columnImpl
} }
func (i *integerColumnImpl) from(subQuery ExpressionTable) projection { func (i *integerColumnImpl) from(subQuery SelectTable) projection {
newIntColumn := IntegerColumn(i.name) newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(i.tableName) newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery) newIntColumn.setSubQuery(subQuery)
@ -94,10 +99,11 @@ func (i *integerColumnImpl) from(subQuery ExpressionTable) projection {
return newIntColumn return newIntColumn
} }
func (i *integerColumnImpl) From(subQuery ExpressionTable) ColumnInteger { func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return i.from(subQuery).(ColumnInteger) return i.from(subQuery).(ColumnInteger)
} }
// IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger { func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{} integerColumn := &integerColumnImpl{}
integerColumn.integerInterfaceImpl.parent = integerColumn integerColumn.integerInterfaceImpl.parent = integerColumn
@ -108,11 +114,13 @@ func IntegerColumn(name string) ColumnInteger {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types.
type ColumnString interface { type ColumnString interface {
StringExpression StringExpression
column column
From(subQuery ExpressionTable) ColumnString From(subQuery SelectTable) ColumnString
} }
type stringColumnImpl struct { type stringColumnImpl struct {
@ -121,7 +129,7 @@ type stringColumnImpl struct {
columnImpl columnImpl
} }
func (i *stringColumnImpl) from(subQuery ExpressionTable) projection { func (i *stringColumnImpl) from(subQuery SelectTable) projection {
newStrColumn := StringColumn(i.name) newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(i.tableName) newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery) newStrColumn.setSubQuery(subQuery)
@ -129,10 +137,11 @@ func (i *stringColumnImpl) from(subQuery ExpressionTable) projection {
return newStrColumn return newStrColumn
} }
func (i *stringColumnImpl) From(subQuery ExpressionTable) ColumnString { func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return i.from(subQuery).(ColumnString) return i.from(subQuery).(ColumnString)
} }
// StringColumn creates named string column.
func StringColumn(name string) ColumnString { func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{} stringColumn := &stringColumnImpl{}
stringColumn.stringInterfaceImpl.parent = stringColumn stringColumn.stringInterfaceImpl.parent = stringColumn
@ -143,11 +152,12 @@ func StringColumn(name string) ColumnString {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnTime is interface for SQL time column.
type ColumnTime interface { type ColumnTime interface {
TimeExpression TimeExpression
column column
From(subQuery ExpressionTable) ColumnTime From(subQuery SelectTable) ColumnTime
} }
type timeColumnImpl struct { type timeColumnImpl struct {
@ -155,7 +165,7 @@ type timeColumnImpl struct {
columnImpl columnImpl
} }
func (i *timeColumnImpl) from(subQuery ExpressionTable) projection { func (i *timeColumnImpl) from(subQuery SelectTable) projection {
newTimeColumn := TimeColumn(i.name) newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(i.tableName) newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery) newTimeColumn.setSubQuery(subQuery)
@ -163,10 +173,11 @@ func (i *timeColumnImpl) from(subQuery ExpressionTable) projection {
return newTimeColumn return newTimeColumn
} }
func (i *timeColumnImpl) From(subQuery ExpressionTable) ColumnTime { func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return i.from(subQuery).(ColumnTime) return i.from(subQuery).(ColumnTime)
} }
// TimeColumn creates named time column
func TimeColumn(name string) ColumnTime { func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{} timeColumn := &timeColumnImpl{}
timeColumn.timeInterfaceImpl.parent = timeColumn timeColumn.timeInterfaceImpl.parent = timeColumn
@ -176,11 +187,12 @@ func TimeColumn(name string) ColumnTime {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface { type ColumnTimez interface {
TimezExpression TimezExpression
column column
From(subQuery ExpressionTable) ColumnTimez From(subQuery SelectTable) ColumnTimez
} }
type timezColumnImpl struct { type timezColumnImpl struct {
@ -189,7 +201,7 @@ type timezColumnImpl struct {
columnImpl columnImpl
} }
func (i *timezColumnImpl) from(subQuery ExpressionTable) projection { func (i *timezColumnImpl) from(subQuery SelectTable) projection {
newTimezColumn := TimezColumn(i.name) newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(i.tableName) newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery) newTimezColumn.setSubQuery(subQuery)
@ -197,10 +209,11 @@ func (i *timezColumnImpl) from(subQuery ExpressionTable) projection {
return newTimezColumn return newTimezColumn
} }
func (i *timezColumnImpl) From(subQuery ExpressionTable) ColumnTimez { func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return i.from(subQuery).(ColumnTimez) return i.from(subQuery).(ColumnTimez)
} }
// TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez { func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{} timezColumn := &timezColumnImpl{}
timezColumn.timezInterfaceImpl.parent = timezColumn timezColumn.timezInterfaceImpl.parent = timezColumn
@ -211,11 +224,12 @@ func TimezColumn(name string) ColumnTimez {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface { type ColumnTimestamp interface {
TimestampExpression TimestampExpression
column column
From(subQuery ExpressionTable) ColumnTimestamp From(subQuery SelectTable) ColumnTimestamp
} }
type timestampColumnImpl struct { type timestampColumnImpl struct {
@ -224,7 +238,7 @@ type timestampColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampColumnImpl) from(subQuery ExpressionTable) projection { func (i *timestampColumnImpl) from(subQuery SelectTable) projection {
newTimestampColumn := TimestampColumn(i.name) newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(i.tableName) newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery) newTimestampColumn.setSubQuery(subQuery)
@ -232,10 +246,11 @@ func (i *timestampColumnImpl) from(subQuery ExpressionTable) projection {
return newTimestampColumn return newTimestampColumn
} }
func (i *timestampColumnImpl) From(subQuery ExpressionTable) ColumnTimestamp { func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return i.from(subQuery).(ColumnTimestamp) return i.from(subQuery).(ColumnTimestamp)
} }
// TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp { func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{} timestampColumn := &timestampColumnImpl{}
timestampColumn.timestampInterfaceImpl.parent = timestampColumn timestampColumn.timestampInterfaceImpl.parent = timestampColumn
@ -246,11 +261,12 @@ func TimestampColumn(name string) ColumnTimestamp {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface { type ColumnTimestampz interface {
TimestampzExpression TimestampzExpression
column column
From(subQuery ExpressionTable) ColumnTimestampz From(subQuery SelectTable) ColumnTimestampz
} }
type timestampzColumnImpl struct { type timestampzColumnImpl struct {
@ -259,7 +275,7 @@ type timestampzColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampzColumnImpl) from(subQuery ExpressionTable) projection { func (i *timestampzColumnImpl) from(subQuery SelectTable) projection {
newTimestampzColumn := TimestampzColumn(i.name) newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(i.tableName) newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery) newTimestampzColumn.setSubQuery(subQuery)
@ -267,10 +283,11 @@ func (i *timestampzColumnImpl) from(subQuery ExpressionTable) projection {
return newTimestampzColumn return newTimestampzColumn
} }
func (i *timestampzColumnImpl) From(subQuery ExpressionTable) ColumnTimestampz { func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return i.from(subQuery).(ColumnTimestampz) return i.from(subQuery).(ColumnTimestampz)
} }
// TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz { func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{} timestampzColumn := &timestampzColumnImpl{}
timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn
@ -281,11 +298,12 @@ func TimestampzColumn(name string) ColumnTimestampz {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnDate is interface of SQL date columns.
type ColumnDate interface { type ColumnDate interface {
DateExpression DateExpression
column column
From(subQuery ExpressionTable) ColumnDate From(subQuery SelectTable) ColumnDate
} }
type dateColumnImpl struct { type dateColumnImpl struct {
@ -294,7 +312,7 @@ type dateColumnImpl struct {
columnImpl columnImpl
} }
func (i *dateColumnImpl) from(subQuery ExpressionTable) projection { func (i *dateColumnImpl) from(subQuery SelectTable) projection {
newDateColumn := DateColumn(i.name) newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(i.tableName) newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery) newDateColumn.setSubQuery(subQuery)
@ -302,10 +320,11 @@ func (i *dateColumnImpl) from(subQuery ExpressionTable) projection {
return newDateColumn return newDateColumn
} }
func (i *dateColumnImpl) From(subQuery ExpressionTable) ColumnDate { func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return i.from(subQuery).(ColumnDate) return i.from(subQuery).(ColumnDate)
} }
// DateColumn creates named date column.
func DateColumn(name string) ColumnDate { func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{} dateColumn := &dateColumnImpl{}
dateColumn.dateInterfaceImpl.parent = dateColumn dateColumn.dateInterfaceImpl.parent = dateColumn

View file

@ -1,5 +1,6 @@
package jet package jet
// DateExpression is interface for all SQL date expressions.
type DateExpression interface { type DateExpression interface {
Expression Expression
@ -63,6 +64,9 @@ func newDateExpressionWrap(expression Expression) DateExpression {
return &dateExpressionWrap return &dateExpressionWrap
} }
// DateExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as date expression.
// Does not add sql cast to generated sql builder output.
func DateExp(expression Expression) DateExpression { func DateExp(expression Expression) DateExpression {
return newDateExpressionWrap(expression) return newDateExpressionWrap(expression)
} }

45
date_expression_test.go Normal file
View file

@ -0,0 +1,45 @@
package jet
import "testing"
var dateVar = Date(2000, 12, 30)
func TestDateExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.EQ(table2ColDate), "(table1.col_date = table2.col_date)")
assertClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::date)", "2000-12-30")
}
func TestDateExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.NOT_EQ(table2ColDate), "(table1.col_date != table2.col_date)")
assertClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::date)", "2000-12-30")
}
func TestDateExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(table2ColDate), "(table1.col_date IS DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(table2ColDate), "(table1.col_date IS NOT DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT(table2ColDate), "(table1.col_date > table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::date)", "2000-12-30")
}
func TestDateExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT_EQ(table2ColDate), "(table1.col_date >= table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::date)", "2000-12-30")
}
func TestDateExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT(table2ColDate), "(table1.col_date < table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::date)", "2000-12-30")
}
func TestDateExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT_EQ(table2ColDate), "(table1.col_date <= table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::date)", "2000-12-30")
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
) )
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface { type DeleteStatement interface {
Statement Statement
@ -48,7 +49,7 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: nil tableName") return errors.New("jet: nil tableName")
} }
if err := d.table.serialize(delete_statement, out); err != nil { if err := d.table.serialize(deleteStatement, out); err != nil {
return err return err
} }
@ -56,11 +57,11 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: deleting without a WHERE clause") return errors.New("jet: deleting without a WHERE clause")
} }
if err := out.writeWhere(delete_statement, d.where); err != nil { if err := out.writeWhere(deleteStatement, d.where); err != nil {
return err return err
} }
if err := out.writeReturning(delete_statement, d.returning); err != nil { if err := out.writeReturning(deleteStatement, d.returning); err != nil {
return err return err
} }
@ -88,14 +89,14 @@ func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) er
return query(d, db, destination) return query(d, db, destination)
} }
func (d *deleteStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(d, db, context, destination) return queryContext(context, d, db, destination)
} }
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db) return exec(d, db)
} }
func (d *deleteStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(d, db, context) return execContext(context, d, db)
} }

4
doc.go
View file

@ -1,5 +1,5 @@
/* /*
Package Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability Package jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability
to easily convert database query result to desired arbitrary structure. to easily convert database query result to desired arbitrary structure.
*/ */
package jet package jet

View file

@ -6,6 +6,7 @@ type enumValue struct {
name string name string
} }
// NewEnumValue creates new named enum value
func NewEnumValue(name string) StringExpression { func NewEnumValue(name string) StringExpression {
enumValue := &enumValue{name: name} enumValue := &enumValue{name: name}

View file

@ -0,0 +1,12 @@
# Quick start example
This package contains sample usage for Jet framework.
Jet generated files of interest are in ./gen folder.
quick-start.go contains code explained at [README.md](../../README.md#quick-start),
with difference of redirecting json output to files(dest.json and dest2.json) rather then to a
standard output.
./gen, dest.json and dest2.json - added to git for presentation purposes.

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"io/ioutil"
// dot import so go code would resemble as much as native SQL // dot import so go code would resemble as much as native SQL
// dot import is not mandatory // dot import is not mandatory
@ -12,15 +13,26 @@ import (
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
"github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model" "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model"
"github.com/go-jet/jet/tests/dbconfig" )
const (
Host = "localhost"
Port = 5432
User = "jet"
Password = "jet"
DBName = "jetdb"
) )
func main() { func main() {
db, err := sql.Open("postgres", dbconfig.ConnectString) // Connect to database
var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)
db, err := sql.Open("postgres", connectString)
panicOnError(err) panicOnError(err)
defer db.Close() defer db.Close()
// Write query
stmt := SELECT( stmt := SELECT(
Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate,
Film.AllColumns, Film.AllColumns,
@ -42,24 +54,13 @@ func main() {
Film.FilmID.ASC(), Film.FilmID.ASC(),
) )
query, args, err := stmt.Sql() // Execute query and store result
panicOnError(err)
fmt.Println("Parameterized query: ")
fmt.Println(query)
fmt.Println("Arguments: ")
fmt.Println(args)
debugSql, err := stmt.DebugSql()
panicOnError(err)
fmt.Println("Debug sql: ")
fmt.Println(debugSql)
var dest []struct { var dest []struct {
model.Actor model.Actor
Films []struct { Films []struct {
model.Film model.Film
Language model.Language Language model.Language
Categories []model.Category Categories []model.Category
} }
@ -68,9 +69,8 @@ func main() {
err = stmt.Query(db, &dest) err = stmt.Query(db, &dest)
panicOnError(err) panicOnError(err)
fmt.Println("dest to json: ") printStatementInfo(stmt)
jsonText, _ := json.MarshalIndent(dest, "", "\t") jsonSave("./dest.json", dest)
fmt.Println(string(jsonText))
// New Destination // New Destination
@ -84,9 +84,35 @@ func main() {
err = stmt.Query(db, &dest2) err = stmt.Query(db, &dest2)
panicOnError(err) panicOnError(err)
fmt.Println("dest2 to json: ") jsonSave("./dest2.json", dest2)
jsonText, _ = json.MarshalIndent(dest2, "", "\t") }
fmt.Println(string(jsonText))
func jsonSave(path string, v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
err := ioutil.WriteFile(path, jsonText, 0644)
if err != nil {
panic(err)
}
}
func printStatementInfo(stmt Statement) {
query, args, err := stmt.Sql()
panicOnError(err)
fmt.Println("Parameterized query: ")
fmt.Println(query)
fmt.Println("Arguments: ")
fmt.Println(args)
debugSQL, err := stmt.DebugSql()
panicOnError(err)
fmt.Println("\n\n==============================")
fmt.Println("\n\nDebug sql: ")
fmt.Println(debugSQL)
} }
func panicOnError(err error) { func panicOnError(err error) {

View file

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
) )
// DB is common database interface used by jet execution
type DB interface { type DB interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

View file

@ -7,16 +7,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/go-jet/jet/execution/internal" "github.com/go-jet/jet/execution/internal"
"github.com/go-jet/jet/internal/utils"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
func Query(db DB, context context.Context, query string, args []interface{}, destinationPtr interface{}) error { // Query executes query with arguments over database connection with context and stores result into destination.
// Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
if destinationPtr == nil { if utils.IsNil(destinationPtr) {
return errors.New("jet: Destination is nil.") return errors.New("jet: Destination is nil")
} }
destinationPtrType := reflect.TypeOf(destinationPtr) destinationPtrType := reflect.TypeOf(destinationPtr)
@ -25,12 +28,12 @@ func Query(db DB, context context.Context, query string, args []interface{}, des
} }
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(db, context, query, args, destinationPtr) return queryToSlice(context, db, query, args, destinationPtr)
} else if destinationPtrType.Elem().Kind() == reflect.Struct { } else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem() tempSliceValue := tempSlicePtrValue.Elem()
err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface()) err := queryToSlice(context, db, query, args, tempSlicePtrValue.Interface())
if err != nil { if err != nil {
return err return err
@ -52,7 +55,7 @@ func Query(db DB, context context.Context, query string, args []interface{}, des
} }
} }
func queryToSlice(db DB, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error { func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error {
if db == nil { if db == nil {
return errors.New("jet: db is nil") return errors.New("jet: db is nil")
} }
@ -142,7 +145,8 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
structPtrValue := getSliceElemPtrAt(slicePtrValue, index) structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true) return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
} else { }
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field) updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
@ -159,7 +163,6 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
return return
} }
} }
}
return return
} }
@ -481,9 +484,8 @@ func valueToString(value reflect.Value) string {
if value.Kind() == reflect.Ptr { if value.Kind() == reflect.Ptr {
if value.IsNil() { if value.IsNil() {
return "nil" return "nil"
} else {
valueInterface = value.Elem().Interface()
} }
valueInterface = value.Elem().Interface()
} else { } else {
valueInterface = value.Interface() valueInterface = value.Interface()
} }
@ -655,13 +657,13 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo) return s.constructGroupKey(groupKeyInfo)
} else { }
groupKeyInfo := s.getGroupKeyInfo(structType, structField) groupKeyInfo := s.getGroupKeyInfo(structType, structField)
s.groupKeyInfoCache[mapKey] = groupKeyInfo s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo) return s.constructGroupKey(groupKeyInfo)
}
} }
func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
@ -742,16 +744,6 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
return index return index
} }
func (s *scanContext) getCellValue(typeName, fieldName string) interface{} {
index := s.typeToColumnIndex(typeName, fieldName)
if index < 0 {
return nil
}
return s.rowElem(index)
}
func (s *scanContext) rowElem(index int) interface{} { func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer) valuer, ok := s.row[index].(driver.Valuer)

View file

@ -5,7 +5,7 @@ import (
"time" "time"
) )
// NullByteArray // NullByteArray struct
type NullByteArray struct { type NullByteArray struct {
ByteArray []byte ByteArray []byte
Valid bool Valid bool
@ -31,7 +31,7 @@ func (nb NullByteArray) Value() (driver.Value, error) {
return nb.ByteArray, nil return nb.ByteArray, nil
} }
//NullTime // NullTime struct
type NullTime struct { type NullTime struct {
Time time.Time Time time.Time
Valid bool // Valid is true if Time is not NULL Valid bool // Valid is true if Time is not NULL
@ -51,6 +51,7 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil return nt.Time, nil
} }
// NullInt32 struct
type NullInt32 struct { type NullInt32 struct {
Int32 int32 Int32 int32
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL
@ -83,6 +84,7 @@ func (n NullInt32) Value() (driver.Value, error) {
return n.Int32, nil return n.Int32, nil
} }
// NullInt16 struct
type NullInt16 struct { type NullInt16 struct {
Int16 int16 Int16 int16
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL
@ -115,6 +117,7 @@ func (n NullInt16) Value() (driver.Value, error) {
return n.Int16, nil return n.Int16, nil
} }
// NullFloat32 struct
type NullFloat32 struct { type NullFloat32 struct {
Float32 float32 Float32 float32
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL

View file

@ -4,12 +4,13 @@ import (
"errors" "errors"
) )
// Common expression interface // Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface { type Expression interface {
clause clause
projection projection
groupByClause groupByClause
OrderByClause orderByClause
// Test expression whether it is a NULL value. // Test expression whether it is a NULL value.
IS_NULL() BoolExpression IS_NULL() BoolExpression
@ -25,16 +26,16 @@ type Expression interface {
AS(alias string) projection AS(alias string) projection
// Expression will be used to sort query result in ascending order // Expression will be used to sort query result in ascending order
ASC() OrderByClause ASC() orderByClause
// Expression will be used to sort query result in ascending order // Expression will be used to sort query result in ascending order
DESC() OrderByClause DESC() orderByClause
} }
type expressionInterfaceImpl struct { type expressionInterfaceImpl struct {
parent Expression parent Expression
} }
func (e *expressionInterfaceImpl) from(subQuery ExpressionTable) projection { func (e *expressionInterfaceImpl) from(subQuery SelectTable) projection {
return e.parent return e.parent
} }
@ -58,11 +59,11 @@ func (e *expressionInterfaceImpl) AS(alias string) projection {
return newAlias(e.parent, alias) return newAlias(e.parent, alias)
} }
func (e *expressionInterfaceImpl) ASC() OrderByClause { func (e *expressionInterfaceImpl) ASC() orderByClause {
return newOrderByClause(e.parent, true) return newOrderByClause(e.parent, true)
} }
func (e *expressionInterfaceImpl) DESC() OrderByClause { func (e *expressionInterfaceImpl) DESC() orderByClause {
return newOrderByClause(e.parent, false) return newOrderByClause(e.parent, false)
} }
@ -145,13 +146,13 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil { if p == nil {
return errors.New("jet: Prefix Expression is nil.") return errors.New("jet: Prefix Expression is nil")
} }
out.writeString(p.operator + " ") out.writeString(p.operator + " ")
if p.expression == nil { if p.expression == nil {
return errors.New("jet: nil prefix Expression.") return errors.New("jet: nil prefix Expression")
} }
if err := p.expression.serialize(statement, out); err != nil { if err := p.expression.serialize(statement, out); err != nil {
return err return err
@ -177,11 +178,11 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil { if p == nil {
return errors.New("jet: Postifx operator Expression is nil.") return errors.New("jet: Postifx operator Expression is nil")
} }
if p.expression == nil { if p.expression == nil {
return errors.New("jet: nil prefix Expression.") return errors.New("jet: nil prefix Expression")
} }
if err := p.expression.serialize(statement, out); err != nil { if err := p.expression.serialize(statement, out); err != nil {
return err return err

View file

@ -1,62 +0,0 @@
package jet
import "errors"
type ExpressionTable interface {
ReadableTable
Alias() string
AllColumns() ProjectionList
}
type expressionTableImpl struct {
readableTableInterfaceImpl
expression Expression
alias string
projections []projection
}
func newExpressionTable(expression Expression, alias string, projections []projection) ExpressionTable {
expTable := &expressionTableImpl{expression: expression, alias: alias}
expTable.readableTableInterfaceImpl.parent = expTable
for _, projection := range projections {
newProjection := projection.from(expTable)
expTable.projections = append(expTable.projections, newProjection)
}
return expTable
}
func (e *expressionTableImpl) Alias() string {
return e.alias
}
func (e *expressionTableImpl) columns() []column {
return nil
}
func (e *expressionTableImpl) AllColumns() ProjectionList {
return e.projections
}
func (e *expressionTableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if e == nil {
return errors.New("jet: Expression table is nil. ")
}
err := e.expression.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeIdentifier(e.alias)
return nil
}

View file

@ -1,5 +1,6 @@
package jet package jet
//FloatExpression is interface for SQL float columns
type FloatExpression interface { type FloatExpression interface {
Expression Expression
numericExpression numericExpression
@ -115,6 +116,9 @@ func newFloatExpressionWrap(expression Expression) FloatExpression {
return &floatExpressionWrap return &floatExpressionWrap
} }
// FloatExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as float expression.
// Does not add sql cast to generated sql builder output.
func FloatExp(expression Expression) FloatExpression { func FloatExp(expression Expression) FloatExpression {
return newFloatExpressionWrap(expression) return newFloatExpressionWrap(expression)
} }

View file

@ -14,6 +14,16 @@ func TestFloatExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.NOT_EQ(Float(2.11)), "(table1.col_float != $1)", float64(2.11)) assertClauseSerialize(t, table1ColFloat.NOT_EQ(Float(2.11)), "(table1.col_float != $1)", float64(2.11))
} }
func TestFloatExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.IS_DISTINCT_FROM(table2ColFloat), "(table1.col_float IS DISTINCT FROM table2.col_float)")
assertClauseSerialize(t, table1ColFloat.IS_DISTINCT_FROM(Float(2.11)), "(table1.col_float IS DISTINCT FROM $1)", float64(2.11))
}
func TestFloatExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.IS_NOT_DISTINCT_FROM(table2ColFloat), "(table1.col_float IS NOT DISTINCT FROM table2.col_float)")
assertClauseSerialize(t, table1ColFloat.IS_NOT_DISTINCT_FROM(Float(2.11)), "(table1.col_float IS NOT DISTINCT FROM $1)", float64(2.11))
}
func TestFloatExpressionGT(t *testing.T) { func TestFloatExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.GT(table2ColFloat), "(table1.col_float > table2.col_float)") assertClauseSerialize(t, table1ColFloat.GT(table2ColFloat), "(table1.col_float > table2.col_float)")
assertClauseSerialize(t, table1ColFloat.GT(Float(2.11)), "(table1.col_float > $1)", float64(2.11)) assertClauseSerialize(t, table1ColFloat.GT(Float(2.11)), "(table1.col_float > $1)", float64(2.11))

View file

@ -2,6 +2,466 @@ package jet
import "errors" import "errors"
// ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression)
}
// ABSi calculates absolute value from int expression
func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression)
}
// SQRT calculates square root of numeric expression
func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression)
}
// CBRT calculates cube root of numeric expression
func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression)
}
// CEIL calculates ceil of float expression
func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression)
}
// FLOOR calculates floor of float expression
func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression)
}
// ROUND calculates round of a float expressions with optional precision
func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 {
return newFloatFunc("ROUND", floatExpression, precision[0])
}
return newFloatFunc("ROUND", floatExpression)
}
// SIGN returns sign of float expression
func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression)
}
// TRUNC calculates trunc of float expression with optional precision
func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 {
return newFloatFunc("TRUNC", floatExpression, precision[0])
}
return newFloatFunc("TRUNC", floatExpression)
}
// LN calculates natural algorithm of float expression
func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression)
}
// LOG calculates logarithm of float expression
func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression)
}
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression)
}
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
func BIT_AND(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_AND", integerExpression)
}
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
func BIT_OR(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_OR", integerExpression)
}
// BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false
func BOOL_AND(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_AND", boolExpression)
}
// BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false
func BOOL_OR(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_OR", boolExpression)
}
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
func COUNT(expression Expression) IntegerExpression {
return newIntegerFunc("COUNT", expression)
}
// EVERY is aggregate function. Returns true if all input values are true, otherwise false
func EVERY(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("EVERY", boolExpression)
}
// MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression)
}
// MAXi is aggregate function. Returns maximum value of int expression across all input values
func MAXi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MAX", integerExpression)
}
// MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression)
}
// MINi is aggregate function. Returns minimum value of int expression across all input values
func MINi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MIN", integerExpression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression)
}
// SUMi is aggregate function. Returns sum of expression across all integer expression.
func SUMi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("SUM", integerExpression)
}
//------------ String functions ------------------//
// BIT_LENGTH returns number of bits in string expression
func BIT_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("BIT_LENGTH", stringExpression)
}
// CHAR_LENGTH returns number of characters in string expression
func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("CHAR_LENGTH", stringExpression)
}
// OCTET_LENGTH returns number of bytes in string expression
func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("OCTET_LENGTH", stringExpression)
}
// LOWER returns string expression in lower case
func LOWER(stringExpression StringExpression) StringExpression {
return newStringFunc("LOWER", stringExpression)
}
// UPPER returns string expression in upper case
func UPPER(stringExpression StringExpression) StringExpression {
return newStringFunc("UPPER", stringExpression)
}
// BTRIM removes the longest string consisting only of characters
// in characters (a space by default) from the start and end of string
func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("BTRIM", stringExpression, trimChars[0])
}
return newStringFunc("BTRIM", stringExpression)
}
// LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("LTRIM", str, trimChars[0])
}
return newStringFunc("LTRIM", str)
}
// RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("RTRIM", str, trimChars[0])
}
return newStringFunc("RTRIM", str)
}
// CHR returns character with the given code.
func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression)
}
//
//func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...)
//}
//
//func CONCAT_WS(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT_WS", expressions...)
//}
// CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding.
func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT", str, srcEncoding, destEncoding)
}
// CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding.
func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_FROM", str, srcEncoding)
}
// CONVERT_TO converts string to dest_encoding.
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_TO", str, toEncoding)
}
// ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
func ENCODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("ENCODE", data, format)
}
// DECODE decodes binary data from textual representation in string.
// Options for format are same as in encode.
func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format)
}
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression {
// args := []expressions{formatStr}
// args = append(args, formatArgs...)
// return newStringFunc("FORMAT", args...)
//}
// INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric
// characters separated by non-alphanumeric characters.
func INITCAP(str StringExpression) StringExpression {
return newStringFunc("INITCAP", str)
}
// LEFT returns first n characters in the string.
// When n is negative, return all but last |n| characters.
func LEFT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("LEFT", str, n)
}
// RIGHT returns last n characters in the string.
// When n is negative, return all but first |n| characters.
func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("RIGHT", str, n)
}
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("LENGTH", str, encoding[0])
}
return newStringFunc("LENGTH", str)
}
// LPAD fills up the string to length length by prepending the characters
// fill (a space by default). If the string is already longer than length
// then it is truncated (on the right).
func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("LPAD", str, length, text[0])
}
return newStringFunc("LPAD", str, length)
}
// RPAD fills up the string to length length by appending the characters
// fill (a space by default). If the string is already longer than length then it is truncated.
func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("RPAD", str, length, text[0])
}
return newStringFunc("RPAD", str, length)
}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal
func MD5(stringExpression StringExpression) StringExpression {
return newStringFunc("MD5", stringExpression)
}
// REPEAT repeats string the specified number of times
func REPEAT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("REPEAT", str, n)
}
// REPLACE replaces all occurrences in string of substring from with substring to
func REPLACE(text, from, to StringExpression) StringExpression {
return newStringFunc("REPLACE", text, from, to)
}
// REVERSE returns reversed string.
func REVERSE(stringExpression StringExpression) StringExpression {
return newStringFunc("REVERSE", stringExpression)
}
// STRPOS returns location of specified substring (same as position(substring in string),
// but note the reversed argument order)
func STRPOS(str, substring StringExpression) IntegerExpression {
return newIntegerFunc("STRPOS", str, substring)
}
// SUBSTR extracts substring
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 {
return newStringFunc("SUBSTR", str, from, count[0])
}
return newStringFunc("SUBSTR", str, from)
}
// TO_ASCII convert string to ASCII from another encoding
func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("TO_ASCII", str, encoding[0])
}
return newStringFunc("TO_ASCII", str)
}
// TO_HEX converts number to its equivalent hexadecimal representation
func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number)
}
//----------Data Type Formatting Functions ----------------------//
// TO_CHAR converts expression to string with format
func TO_CHAR(expression Expression, format StringExpression) StringExpression {
return newStringFunc("TO_CHAR", expression, format)
}
// TO_DATE converts string to date using format
func TO_DATE(dateStr, format StringExpression) DateExpression {
return newDateFunc("TO_DATE", dateStr, format)
}
// TO_NUMBER converts string to numeric using format
func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format)
}
// TO_TIMESTAMP converts string to time stamp with time zone using format
func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression {
return newTimestampzFunc("TO_TIMESTAMP", timestampzStr, format)
}
//----------------- Date/Time Functions and Operators ---------------//
// CURRENT_DATE returns current date
func CURRENT_DATE() DateExpression {
dateFunc := newDateFunc("CURRENT_DATE")
dateFunc.noBrackets = true
return dateFunc
}
// CURRENT_TIME returns current time with time zone
func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
timezFunc.noBrackets = true
return timezFunc
}
// CURRENT_TIMESTAMP returns current timestamp with time zone
func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
timestampzFunc.noBrackets = true
return timestampzFunc
}
// LOCALTIME returns local time of day using optional precision
func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
timeFunc.noBrackets = true
return timeFunc
}
// LOCALTIMESTAMP returns current date and time using optional precision
func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0]))
} else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP")
}
timestampFunc.noBrackets = true
return timestampFunc
}
// NOW returns current date and time
func NOW() TimestampzExpression {
return newTimestampzFunc("NOW")
}
// --------------- Conditional Expressions Functions -------------//
// COALESCE function returns the first of its arguments that is not null.
func COALESCE(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("COALESCE", allValues, nil)
}
// NULLIF function returns a null value if value1 equals value2; otherwise it returns value1.
func NULLIF(value1, value2 Expression) Expression {
return newFunc("NULLIF", []Expression{value1, value2}, nil)
}
// GREATEST selects the largest value from a list of expressions
func GREATEST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("GREATEST", allValues, nil)
}
// LEAST selects the smallest value from a list of expressions
func LEAST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("LEAST", allValues, nil)
}
//--------------------------------------------------------------------//
type funcExpressionImpl struct { type funcExpressionImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
@ -175,375 +635,3 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
return timestampzFunc return timestampzFunc
} }
func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression)
}
func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression)
}
func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression)
}
func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression)
}
func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression)
}
func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression)
}
func ROUND(floatExpression FloatExpression, intExpression ...IntegerExpression) FloatExpression {
if len(intExpression) > 0 {
return newFloatFunc("ROUND", floatExpression, intExpression[0])
}
return newFloatFunc("ROUND", floatExpression)
}
func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression)
}
func TRUNC(floatExpression FloatExpression, intExpression ...IntegerExpression) FloatExpression {
if len(intExpression) > 0 {
return newFloatFunc("TRUNC", floatExpression, intExpression[0])
}
return newFloatFunc("TRUNC", floatExpression)
}
func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression)
}
func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression)
}
// ----------------- Aggregate functions -------------------//
func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression)
}
func BIT_AND(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_AND", integerExpression)
}
func BIT_OR(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_OR", integerExpression)
}
func BOOL_AND(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_AND", boolExpression)
}
func BOOL_OR(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_OR", boolExpression)
}
func COUNT(expression Expression) IntegerExpression {
return newIntegerFunc("COUNT", expression)
}
func EVERY(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("EVERY", boolExpression)
}
func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression)
}
func MAXi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MAX", integerExpression)
}
func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression)
}
func MINi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MIN", integerExpression)
}
func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression)
}
func SUMi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("SUM", integerExpression)
}
//------------ String functions ------------------//
func BIT_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("BIT_LENGTH", stringExpression)
}
func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("CHAR_LENGTH", stringExpression)
}
func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("OCTET_LENGTH", stringExpression)
}
func LOWER(stringExpression StringExpression) StringExpression {
return newStringFunc("LOWER", stringExpression)
}
func UPPER(stringExpression StringExpression) StringExpression {
return newStringFunc("UPPER", stringExpression)
}
func BTRIM(stringExpression StringExpression) StringExpression {
return newStringFunc("BTRIM", stringExpression)
}
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("LTRIM", str, trimChars[0])
}
return newStringFunc("LTRIM", str)
}
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("RTRIM", str, trimChars[0])
}
return newStringFunc("RTRIM", str)
}
func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression)
}
//
//func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...)
//}
//
//func CONCAT_WS(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT_WS", expressions...)
//}
func CONVERT(str StringExpression, fromEncoding StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT", str, fromEncoding, toEncoding)
}
func CONVERT_FROM(str StringExpression, fromEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_FROM", str, fromEncoding)
}
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_TO", str, toEncoding)
}
func ENCODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("ENCODE", data, format)
}
func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format)
}
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression {
// args := []expressions{formatStr}
// args = append(args, formatArgs...)
// return newStringFunc("FORMAT", args...)
//}
func INITCAP(str StringExpression) StringExpression {
return newStringFunc("INITCAP", str)
}
func LEFT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("LEFT", str, n)
}
func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("RIGHT", str, n)
}
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("LENGTH", str, encoding[0])
}
return newStringFunc("LENGTH", str)
}
func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("LPAD", str, length, text[0])
}
return newStringFunc("LPAD", str, length)
}
func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("RPAD", str, length, text[0])
}
return newStringFunc("RPAD", str, length)
}
func MD5(stringExpression StringExpression) StringExpression {
return newStringFunc("MD5", stringExpression)
}
func REPEAT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("REPEAT", str, n)
}
func REPLACE(text, from, to StringExpression) StringExpression {
return newStringFunc("REPLACE", text, from, to)
}
func REVERSE(stringExpression StringExpression) StringExpression {
return newStringFunc("REVERSE", stringExpression)
}
func STRPOS(str, substring StringExpression) IntegerExpression {
return newIntegerFunc("STRPOS", str, substring)
}
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 {
return newStringFunc("SUBSTR", str, from, count[0])
}
return newStringFunc("SUBSTR", str, from)
}
func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("TO_ASCII", str, encoding[0])
}
return newStringFunc("TO_ASCII", str)
}
func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number)
}
//----------Data Type Formatting Functions ----------------------//
func TO_CHAR(expression Expression, text StringExpression) StringExpression {
return newStringFunc("TO_CHAR", expression, text)
}
func TO_DATE(dateStr, format StringExpression) DateExpression {
return newDateFunc("TO_DATE", dateStr, format)
}
func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format)
}
func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression {
return newTimestampzFunc("TO_TIMESTAMP", timestampzStr, format)
}
//----------------- Date/Time Functions and Operators ---------------//
func CURRENT_DATE() DateExpression {
dateFunc := newDateFunc("CURRENT_DATE")
dateFunc.noBrackets = true
return dateFunc
}
func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
timezFunc.noBrackets = true
return timezFunc
}
func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
timestampzFunc.noBrackets = true
return timestampzFunc
}
func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
timeFunc.noBrackets = true
return timeFunc
}
func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0]))
} else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP")
}
timestampFunc.noBrackets = true
return timestampFunc
}
func NOW() TimestampzExpression {
return newTimestampzFunc("NOW")
}
// --------------- Conditional Expressions Functions -------------//
func COALESCE(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("COALESCE", allValues, nil)
}
func NULLIF(value1, value2 Expression) Expression {
return newFunc("NULLIF", []Expression{value1, value2}, nil)
}
func GREATEST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("GREATEST", allValues, nil)
}
func LEAST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("LEAST", allValues, nil)
}

View file

@ -1,7 +1,6 @@
package jet package jet
import ( import (
"gotest.tools/assert"
"testing" "testing"
) )
@ -157,13 +156,7 @@ func TestFuncLEAST(t *testing.T) {
assertClauseSerialize(t, LEAST(Float(11.2222), NULL, String("str")), "LEAST($1, NULL, $2)", float64(11.2222), "str") assertClauseSerialize(t, LEAST(Float(11.2222), NULL, String("str")), "LEAST($1, NULL, $2)", float64(11.2222), "str")
} }
func TestInterval(t *testing.T) { func TestTO_ASCII(t *testing.T) {
query := INTERVAL(`6 years 5 months 4 days 3 hours 2 minutes 1 second`) assertClauseSerialize(t, TO_ASCII(String("Karel")), `TO_ASCII($1)`, "Karel")
assertClauseSerialize(t, TO_ASCII(String("Karel")), `TO_ASCII($1)`, "Karel")
queryData := &sqlBuilder{}
err := query.serialize(select_statement, queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `INTERVAL $1`)
} }

View file

@ -1,5 +1,6 @@
package metadata package metadata
// MetaData interface
type MetaData interface { type MetaData interface {
Name() string Name() string
} }

View file

@ -1,4 +1,4 @@
package postgres_metadata package postgresmeta
import ( import (
"database/sql" "database/sql"
@ -7,6 +7,7 @@ import (
"strings" "strings"
) )
// ColumnInfo metadata struct
type ColumnInfo struct { type ColumnInfo struct {
Name string Name string
IsNullable bool IsNullable bool
@ -14,6 +15,7 @@ type ColumnInfo struct {
EnumName string EnumName string
} }
// SqlBuilderColumnType returns type of jet sql builder column
func (c ColumnInfo) SqlBuilderColumnType() string { func (c ColumnInfo) SqlBuilderColumnType() string {
switch c.DataType { switch c.DataType {
case "boolean": case "boolean":
@ -41,6 +43,7 @@ func (c ColumnInfo) SqlBuilderColumnType() string {
} }
} }
// GoBaseType returns model type for column info.
func (c ColumnInfo) GoBaseType() string { func (c ColumnInfo) GoBaseType() string {
switch c.DataType { switch c.DataType {
case "USER-DEFINED": case "USER-DEFINED":
@ -72,6 +75,8 @@ func (c ColumnInfo) GoBaseType() string {
} }
} }
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnInfo) GoModelType() string { func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType() typeStr := c.GoBaseType()
if c.IsNullable { if c.IsNullable {
@ -81,6 +86,7 @@ func (c ColumnInfo) GoModelType() string {
return typeStr return typeStr
} }
// GoModelTag returns model field tag for column
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string { func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{} tags := []string{}

View file

@ -1,15 +1,17 @@
package postgres_metadata package postgresmeta
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
) )
// EnumInfo struct
type EnumInfo struct { type EnumInfo struct {
name string name string
Values []string Values []string
} }
// Name returns enum name
func (e EnumInfo) Name() string { func (e EnumInfo) Name() string {
return e.name return e.name
} }

View file

@ -1,10 +1,11 @@
package postgres_metadata package postgresmeta
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
) )
// SchemaInfo metadata struct
type SchemaInfo struct { type SchemaInfo struct {
DatabaseName string DatabaseName string
Name string Name string
@ -12,6 +13,7 @@ type SchemaInfo struct {
EnumInfos []metadata.MetaData EnumInfos []metadata.MetaData
} }
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) { func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
schemaInfo.DatabaseName = databaseName schemaInfo.DatabaseName = databaseName

View file

@ -1,10 +1,11 @@
package postgres_metadata package postgresmeta
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
) )
// TableInfo metadata struct
type TableInfo struct { type TableInfo struct {
SchemaName string SchemaName string
name string name string
@ -12,14 +13,17 @@ type TableInfo struct {
Columns []ColumnInfo Columns []ColumnInfo
} }
// Name returns table info name
func (t TableInfo) Name() string { func (t TableInfo) Name() string {
return t.name return t.name
} }
func (t TableInfo) IsPrimaryKey(columnName string) bool { // IsPrimaryKey returns if column is a part of primary key
return t.PrimaryKeys[columnName] func (t TableInfo) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column]
} }
// MutableColumns returns list of mutable columns for table
func (t TableInfo) MutableColumns() []ColumnInfo { func (t TableInfo) MutableColumns() []ColumnInfo {
ret := []ColumnInfo{} ret := []ColumnInfo{}
@ -34,6 +38,7 @@ func (t TableInfo) MutableColumns() []ColumnInfo {
return ret return ret
} }
// GetImports returns model imports for table.
func (t TableInfo) GetImports() []string { func (t TableInfo) GetImports() []string {
imports := map[string]string{} imports := map[string]string{}
@ -57,10 +62,12 @@ func (t TableInfo) GetImports() []string {
return ret return ret
} }
// GoStructName returns go struct name for sql builder
func (t TableInfo) GoStructName() string { func (t TableInfo) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table" return utils.ToGoIdentifier(t.name) + "Table"
} }
// GetTableInfo returns table info metadata
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) { func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) {
tableInfo.SchemaName = schemaName tableInfo.SchemaName = schemaName

View file

@ -4,16 +4,17 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/metadata/postgres-metadata" "github.com/go-jet/jet/generator/internal/metadata/postgresmeta"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
_ "github.com/lib/pq"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
) )
// DBConnection contains postgres connection details
type DBConnection struct { type DBConnection struct {
Host string Host string
Port string Port int
User string User string
Password string Password string
SslMode string SslMode string
@ -23,10 +24,11 @@ type DBConnection struct {
SchemaName string SchemaName string
} }
func Generate(destDir string, genData DBConnection) error { // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
genData.Host, genData.Port, genData.User, genData.Password, genData.DBName, genData.SslMode, genData.Params) dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString) fmt.Println("Connecting to postgres database: " + connectionString)
@ -43,7 +45,7 @@ func Generate(destDir string, genData DBConnection) error {
} }
fmt.Println("Retrieving schema information...") fmt.Println("Retrieving schema information...")
schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DBName, genData.SchemaName) schemaInfo, err := postgresmeta.GetSchemaInfo(db, dbConn.DBName, dbConn.SchemaName)
if err != nil { if err != nil {
return err return err
@ -55,7 +57,7 @@ func Generate(destDir string, genData DBConnection) error {
return nil return nil
} }
schemaGenPath := path.Join(destDir, genData.DBName, genData.SchemaName) schemaGenPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
fmt.Println("Destination directory:", schemaGenPath) fmt.Println("Destination directory:", schemaGenPath)
fmt.Println("Cleaning up destination directory...") fmt.Println("Cleaning up destination directory...")
err = utils.CleanUpGeneratedFiles(schemaGenPath) err = utils.CleanUpGeneratedFiles(schemaGenPath)
@ -99,7 +101,7 @@ func Generate(destDir string, genData DBConnection) error {
return nil return nil
} }
func generate(schemaInfo postgres_metadata.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error { func generate(schemaInfo postgresmeta.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName) modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
err := utils.EnsureDirPath(modelDirPath) err := utils.EnsureDirPath(modelDirPath)

View file

@ -5,8 +5,10 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
) )
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface { type InsertStatement interface {
Statement Statement
@ -81,11 +83,11 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.newLine() queryData.newLine()
queryData.writeString("INSERT INTO") queryData.writeString("INSERT INTO")
if isNil(i.table) { if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil") return "", nil, errors.New("jet: table is nil")
} }
err = i.table.serialize(insert_statement, queryData) err = i.table.serialize(insertStatement, queryData)
if err != nil { if err != nil {
return return
@ -114,8 +116,8 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
if len(i.rows) > 0 { if len(i.rows) > 0 {
queryData.writeString("VALUES") queryData.writeString("VALUES")
for row_i, row := range i.rows { for rowIndex, row := range i.rows {
if row_i > 0 { if rowIndex > 0 {
queryData.writeString(",") queryData.writeString(",")
} }
@ -123,7 +125,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.newLine() queryData.newLine()
queryData.writeString("(") queryData.writeString("(")
err = serializeClauseList(insert_statement, row, queryData) err = serializeClauseList(insertStatement, row, queryData)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -135,14 +137,14 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
if i.query != nil { if i.query != nil {
err = i.query.serialize(insert_statement, queryData) err = i.query.serialize(insertStatement, queryData)
if err != nil { if err != nil {
return return
} }
} }
if err = queryData.writeReturning(insert_statement, i.returning); err != nil { if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
return return
} }
@ -155,14 +157,14 @@ func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) er
return query(i, db, destination) return query(i, db, destination)
} }
func (i *insertStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(i, db, context, destination) return queryContext(context, i, db, destination)
} }
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db) return exec(i, db)
} }
func (i *insertStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(i, db, context) return execContext(context, i, db)
} }

View file

@ -81,13 +81,13 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(toInsert). MODEL(toInsert).
MODEL(&toInsert) MODEL(&toInsert)
expectedSql := ` expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2), ($1, $2),
($3, $4); ($3, $4);
` `
assertStatement(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11)) assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
} }
func TestInsertValuesFromModelColumnMismatch(t *testing.T) { func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -125,23 +125,23 @@ func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1). stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1)) QUERY(table1.SELECT(table1Col1))
var expectedSql = ` var expectedSQL = `
INSERT INTO db.table1 (col1) ( INSERT INTO db.table1 (col1) (
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
FROM db.table1 FROM db.table1
); );
` `
assertStatement(t, stmt, expectedSql) assertStatement(t, stmt, expectedSQL)
} }
func TestInsertDefaultValue(t *testing.T) { func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat). stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two") VALUES(DEFAULT, "two")
var expectedSql = ` var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, $1); (DEFAULT, $1);
` `
assertStatement(t, stmt, expectedSql, "two") assertStatement(t, stmt, expectedSQL, "two")
} }

View file

@ -1,5 +1,6 @@
package jet package jet
// IntegerExpression interface
type IntegerExpression interface { type IntegerExpression interface {
Expression Expression
numericExpression numericExpression
@ -180,6 +181,9 @@ func newIntExpressionWrap(expression Expression) IntegerExpression {
return &intExpressionWrap return &intExpressionWrap
} }
// IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output.
func IntExp(expression Expression) IntegerExpression { func IntExp(expression Expression) IntegerExpression {
return newIntExpressionWrap(expression) return newIntExpressionWrap(expression)
} }

View file

@ -8,42 +8,9 @@ import (
"unicode" "unicode"
) )
// CamelToSnake converts a given string to snake case // SnakeToCamel returns a string converted from snake case to uppercase
func CamelToSnake(s string) string { func SnakeToCamel(s string) string {
var result string return snakeToCamel(s, true)
var words []string
var lastPos int
rs := []rune(s)
for i := 0; i < len(rs); i++ {
if i > 0 && unicode.IsUpper(rs[i]) {
if initialism := startsWithInitialism(s[lastPos:]); initialism != "" {
words = append(words, initialism)
i += len(initialism) - 1
lastPos = i
continue
}
words = append(words, s[lastPos:i])
lastPos = i
}
}
// append the last word
if s[lastPos:] != "" {
words = append(words, s[lastPos:])
}
for k, word := range words {
if k > 0 {
result += "_"
}
result += strings.ToLower(word)
}
return result
} }
func snakeToCamel(s string, upperCase bool) string { func snakeToCamel(s string, upperCase bool) string {
@ -54,28 +21,6 @@ func snakeToCamel(s string, upperCase bool) string {
words := strings.Split(s, "_") words := strings.Split(s, "_")
//// if there is no underscore, first try commons and then just return
//if len(words) == 1 {
// if exception := snakeToCamelExceptions[words[0]]; len(exception) > 0 {
// return exception
// }
//
// if upperCase {
// if upper := strings.ToUpper(words[0]); commonInitialisms[upper] {
// return upper
// }
// }
//
// w := []rune(s)
// if upperCase {
// w[0] = unicode.ToUpper(w[0])
// } else {
// w[0] = unicode.ToLower(w[0])
// }
//
// return string(w)
//}
for i, word := range words { for i, word := range words {
if exception := snakeToCamelExceptions[word]; len(exception) > 0 { if exception := snakeToCamelExceptions[word]; len(exception) > 0 {
result += exception result += exception
@ -117,28 +62,6 @@ func camelizeWord(word string, force bool) string {
return string(runes) return string(runes)
} }
// SnakeToCamel returns a string converted from snake case to uppercase
func SnakeToCamel(s string) string {
return snakeToCamel(s, true)
}
// SnakeToCamelLower returns a string converted from snake case to lowercase
func SnakeToCamelLower(s string) string {
return snakeToCamel(s, false)
}
// startsWithInitialism returns the initialism if the given string begins with it
func startsWithInitialism(s string) string {
var initialism string
// the longest initialism is 5 char, the shortest 2
for i := 1; i <= 5; i++ {
if len(s) > i-1 && commonInitialisms[s[:i]] {
initialism = s[:i]
}
}
return initialism
}
// commonInitialisms, taken from // commonInitialisms, taken from
// https://github.com/golang/lint/blob/206c0f020eba0f7fbcfbc467a5eb808037df2ed6/lint.go#L731 // https://github.com/golang/lint/blob/206c0f020eba0f7fbcfbc467a5eb808037df2ed6/lint.go#L731
var commonInitialisms = map[string]bool{ var commonInitialisms = map[string]bool{

View file

@ -6,56 +6,6 @@ import (
) )
var _ = Describe("Snaker", func() { var _ = Describe("Snaker", func() {
Describe("CamelToSnake test", func() {
It("should return an empty string on an empty input", func() {
Expect(CamelToSnake("")).To(Equal(""))
})
It("should work with one word", func() {
Expect(CamelToSnake("One")).To(Equal("one"))
})
It("should return an uppercase string as seperate words", func() {
Expect(CamelToSnake("ONE")).To(Equal("o_n_e"))
})
It("should return ID as lowercase", func() {
Expect(CamelToSnake("ID")).To(Equal("id"))
})
It("should work with a single lowercase character", func() {
Expect(CamelToSnake("i")).To(Equal("i"))
})
It("should work with a single uppcase character", func() {
Expect(CamelToSnake("I")).To(Equal("i"))
})
It("should return a long text as expected", func() {
Expect(CamelToSnake("ThisHasToBeConvertedCorrectlyID")).To(
Equal("this_has_to_be_converted_correctly_id"))
})
It("should return the text as expected if the initialism is in the middle", func() {
Expect(CamelToSnake("ThisIDIsFine")).To(Equal("this_id_is_fine"))
})
It("should work with long initialism", func() {
Expect(CamelToSnake("ThisHTTPSConnection")).To(Equal("this_https_connection"))
})
It("should work with multi initialisms", func() {
Expect(CamelToSnake("HelloHTTPSConnectionID")).To(Equal("hello_https_connection_id"))
})
It("sould work with concat initialisms", func() {
Expect(CamelToSnake("HTTPSID")).To(Equal("https_id"))
})
It("sould work with initialism where only certain characters are uppercase", func() {
Expect(CamelToSnake("OAuthClient")).To(Equal("oauth_client"))
})
})
Describe("SnakeToCamel test", func() { Describe("SnakeToCamel test", func() {
It("should return an empty string on an empty input", func() { It("should return an empty string on an empty input", func() {
@ -83,39 +33,8 @@ var _ = Describe("Snaker", func() {
Expect(SnakeToCamel("id")).To(Equal("ID")) Expect(SnakeToCamel("id")).To(Equal("ID"))
}) })
It("sould work with initialism where only certain characters are uppercase", func() { It("should work with initialism where only certain characters are uppercase", func() {
Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient")) Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient"))
}) })
}) })
Describe("SnakeToCamelLower test", func() {
It("should return an empty string on an empty input", func() {
Ω(SnakeToCamelLower("")).To(Equal(""))
})
It("should not blow up on trailing _", func() {
Ω(SnakeToCamelLower("potato_")).To(Equal("potato"))
})
It("should return a snaked text as camel case", func() {
Ω(SnakeToCamelLower("this_has_to_be_uppercased")).To(
Equal("thisHasToBeUppercased"))
})
It("should return a snaked text as camel case, except the word ID", func() {
Ω(SnakeToCamelLower("this_is_an_id")).To(Equal("thisIsAnID"))
})
It("should return 'id' not as uppercase", func() {
Ω(SnakeToCamelLower("this_is_an_identifier")).To(Equal("thisIsAnIdentifier"))
})
It("should simply work with id", func() {
Ω(SnakeToCamelLower("id")).To(Equal("id"))
})
It("should simply work with leading id", func() {
Ω(SnakeToCamelLower("id_me_please")).To(Equal("idMePlease"))
})
})
}) })

View file

@ -6,24 +6,24 @@ import (
"go/format" "go/format"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strconv" "strconv"
"strings" "strings"
"text/template" "text/template"
"time" "time"
) )
// ToGoIdentifier converts database to Go identifier.
func ToGoIdentifier(databaseIdentifier string) string { func ToGoIdentifier(databaseIdentifier string) string {
if len(databaseIdentifier) == 0 {
return databaseIdentifier
}
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
} }
// ToGoFileName converts database identifier to Go file name.
func ToGoFileName(databaseIdentifier string) string { func ToGoFileName(databaseIdentifier string) string {
return strings.ToLower(replaceInvalidChars(databaseIdentifier)) return strings.ToLower(replaceInvalidChars(databaseIdentifier))
} }
// SaveGoFile saves go file at folder dir, with name fileName and contents text.
func SaveGoFile(dirPath, fileName string, text []byte) error { func SaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath := filepath.Join(dirPath, fileName) + ".go" newGoFilePath := filepath.Join(dirPath, fileName) + ".go"
@ -49,6 +49,7 @@ func SaveGoFile(dirPath, fileName string, text []byte) error {
return nil return nil
} }
// EnsureDirPath ensures dir path exists. If path does not exist, creates new path.
func EnsureDirPath(dirPath string) error { func EnsureDirPath(dirPath string) error {
if _, err := os.Stat(dirPath); os.IsNotExist(err) { if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err := os.MkdirAll(dirPath, os.ModePerm) err := os.MkdirAll(dirPath, os.ModePerm)
@ -61,6 +62,7 @@ func EnsureDirPath(dirPath string) error {
return nil return nil
} }
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) { func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
@ -82,6 +84,7 @@ func GenerateTemplate(templateText string, templateData interface{}) ([]byte, er
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// CleanUpGeneratedFiles deletes everything at folder dir.
func CleanUpGeneratedFiles(dir string) error { func CleanUpGeneratedFiles(dir string) error {
exist, err := DirExists(dir) exist, err := DirExists(dir)
@ -100,6 +103,7 @@ func CleanUpGeneratedFiles(dir string) error {
return nil return nil
} }
// DirExists checks if folder at path exist.
func DirExists(path string) (bool, error) { func DirExists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
if err == nil { if err == nil {
@ -118,8 +122,7 @@ func replaceInvalidChars(str string) string {
return str return str
} }
// github.com/lib/pq // FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq
// FormatTimestamp formats t into Postgres' text format for timestamps.
func FormatTimestamp(t time.Time) []byte { func FormatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the // Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go. // minus sign preferred by Go.
@ -152,3 +155,7 @@ func FormatTimestamp(t time.Time) []byte {
} }
return b return b
} }
func IsNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
}

View file

@ -6,6 +6,7 @@ import (
) )
func TestToGoIdentifier(t *testing.T) { func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier(""), "")
assert.Equal(t, ToGoIdentifier("uuid"), "UUID") assert.Equal(t, ToGoIdentifier("uuid"), "UUID")
assert.Equal(t, ToGoIdentifier("col1"), "Col1") assert.Equal(t, ToGoIdentifier("col1"), "Col1")
assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13") assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13")

View file

@ -1,11 +1,14 @@
package jet package jet
const ( const (
// DEFAULT is jet equivalent of SQL DEFAULT
DEFAULT keywordClause = "DEFAULT" DEFAULT keywordClause = "DEFAULT"
) )
var ( var (
// NULL is jet equivalent of SQL NULL
NULL = newNullLiteral() NULL = newNullLiteral()
// STAR is jet equivalent of SQL *
STAR = newStarLiteral() STAR = newStarLiteral()
) )

View file

@ -27,7 +27,7 @@ func (l literalExpression) serialize(statement statementType, out *sqlBuilder, o
if l.constant { if l.constant {
out.insertConstantArgument(l.value) out.insertConstantArgument(l.value)
} else { } else {
out.insertPreparedArgument(l.value) out.insertParametrizedArgument(l.value)
} }
return nil return nil
@ -38,6 +38,7 @@ type integerLiteralExpression struct {
integerInterfaceImpl integerInterfaceImpl
} }
// Int is constructor for integer expressions literals.
func Int(value int64) IntegerExpression { func Int(value int64) IntegerExpression {
numLiteral := &integerLiteralExpression{} numLiteral := &integerLiteralExpression{}
@ -55,6 +56,7 @@ type boolLiteralExpression struct {
literalExpression literalExpression
} }
// Bool creates new bool literal expression
func Bool(value bool) BoolExpression { func Bool(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{} boolLiteralExpression := boolLiteralExpression{}
@ -70,6 +72,7 @@ type floatLiteral struct {
literalExpression literalExpression
} }
// Float creates new float literal expression
func Float(value float64) FloatExpression { func Float(value float64) FloatExpression {
floatLiteral := floatLiteral{} floatLiteral := floatLiteral{}
floatLiteral.literalExpression = *literal(value) floatLiteral.literalExpression = *literal(value)
@ -85,6 +88,7 @@ type stringLiteral struct {
literalExpression literalExpression
} }
// String creates new string literal expression
func String(value string) StringExpression { func String(value string) StringExpression {
stringLiteral := stringLiteral{} stringLiteral := stringLiteral{}
stringLiteral.literalExpression = *literal(value) stringLiteral.literalExpression = *literal(value)
@ -100,6 +104,7 @@ type timeLiteral struct {
literalExpression literalExpression
} }
// Time creates new time literal expression
func Time(hour, minute, second, milliseconds int) TimeExpression { func Time(hour, minute, second, milliseconds int) TimeExpression {
timeLiteral := &timeLiteral{} timeLiteral := &timeLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d", hour, minute, second, milliseconds) timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d", hour, minute, second, milliseconds)
@ -116,6 +121,7 @@ type timezLiteral struct {
literalExpression literalExpression
} }
// Timez creates new time with time zone literal expression
func Timez(hour, minute, second, milliseconds, timezone int) TimezExpression { func Timez(hour, minute, second, milliseconds, timezone int) TimezExpression {
timezLiteral := &timezLiteral{} timezLiteral := &timezLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d %+03d", hour, minute, second, milliseconds, timezone) timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d %+03d", hour, minute, second, milliseconds, timezone)
@ -132,6 +138,7 @@ type timestampLiteral struct {
literalExpression literalExpression
} }
// Timestamp creates new timestamp literal expression
func Timestamp(year, month, day, hour, minute, second, milliseconds int) TimestampExpression { func Timestamp(year, month, day, hour, minute, second, milliseconds int) TimestampExpression {
timestampLiteral := &timestampLiteral{} timestampLiteral := &timestampLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d", year, month, day, hour, minute, second, milliseconds) timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d", year, month, day, hour, minute, second, milliseconds)
@ -148,6 +155,7 @@ type timestampzLiteral struct {
literalExpression literalExpression
} }
// Timestampz creates new timestamp with time zone literal expression
func Timestampz(year, month, day, hour, minute, second, milliseconds, timezone int) TimestampzExpression { func Timestampz(year, month, day, hour, minute, second, milliseconds, timezone int) TimestampzExpression {
timestampzLiteral := &timestampzLiteral{} timestampzLiteral := &timestampzLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d %+04d", timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d %+04d",
@ -166,6 +174,7 @@ type dateLiteral struct {
literalExpression literalExpression
} }
//Date creates new date expression
func Date(year, month, day int) DateExpression { func Date(year, month, day int) DateExpression {
dateLiteral := &dateLiteral{} dateLiteral := &dateLiteral{}
@ -226,6 +235,7 @@ func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...se
return err return err
} }
// WRAP wraps list of expressions with brackets '(' and ')'
func WRAP(expression ...Expression) Expression { func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression} wrap := &wrap{expressions: expression}
wrap.expressionInterfaceImpl.parent = wrap wrap.expressionInterfaceImpl.parent = wrap
@ -245,6 +255,8 @@ func (n *rawExpression) serialize(statement statementType, out *sqlBuilder, opti
return nil return nil
} }
// RAW can be used for any unsupported functions, operators or expressions.
// For example: RAW("current_database()")
func RAW(raw string) Expression { func RAW(raw string) Expression {
rawExp := &rawExpression{raw: raw} rawExp := &rawExpression{raw: raw}
rawExp.expressionInterfaceImpl.parent = rawExp rawExp.expressionInterfaceImpl.parent = rawExp

View file

@ -7,8 +7,10 @@ import (
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
) )
// TableLockMode is a type of possible SQL table lock
type TableLockMode string type TableLockMode string
// Lock types for LockStatement.
const ( const (
LOCK_ACCESS_SHARE = "ACCESS SHARE" LOCK_ACCESS_SHARE = "ACCESS SHARE"
LOCK_ROW_SHARE = "ROW SHARE" LOCK_ROW_SHARE = "ROW SHARE"
@ -20,6 +22,7 @@ const (
LOCK_ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE" LOCK_ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
) )
// LockStatement interface for SQL LOCK statement
type LockStatement interface { type LockStatement interface {
Statement Statement
@ -33,6 +36,7 @@ type lockStatementImpl struct {
nowait bool nowait bool
} }
// LOCK creates lock statement for list of tables.
func LOCK(tables ...WritableTable) LockStatement { func LOCK(tables ...WritableTable) LockStatement {
return &lockStatementImpl{ return &lockStatementImpl{
tables: tables, tables: tables,
@ -55,11 +59,11 @@ func (l *lockStatementImpl) DebugSql() (query string, err error) {
func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) { func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) {
if l == nil { if l == nil {
return "", nil, errors.New("jet: nil Statement.") return "", nil, errors.New("jet: nil Statement")
} }
if len(l.tables) == 0 { if len(l.tables) == 0 {
return "", nil, errors.New("jet: There is no table selected to be locked. ") return "", nil, errors.New("jet: There is no table selected to be locked")
} }
out := &sqlBuilder{} out := &sqlBuilder{}
@ -72,7 +76,7 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
out.writeString(", ") out.writeString(", ")
} }
err := table.serialize(lock_statement, out) err := table.serialize(lockStatement, out)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -97,14 +101,14 @@ func (l *lockStatementImpl) Query(db execution.DB, destination interface{}) erro
return query(l, db, destination) return query(l, db, destination)
} }
func (l *lockStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { func (l *lockStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(l, db, context, destination) return queryContext(context, l, db, destination)
} }
func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) { func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) {
return exec(l, db) return exec(l, db)
} }
func (l *lockStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { func (l *lockStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(l, db, context) return execContext(context, l, db)
} }

View file

@ -1,5 +1,6 @@
package jet package jet
// NumericExpression is common interface for all integer and float expressions
type NumericExpression interface { type NumericExpression interface {
Expression Expression
numericExpression numericExpression

View file

@ -4,17 +4,19 @@ import "errors"
//----------- Logical operators ---------------// //----------- Logical operators ---------------//
// Returns negation of bool expression expr // NOT returns negation of bool expression result
func NOT(exp BoolExpression) BoolExpression { func NOT(exp BoolExpression) BoolExpression {
return newPrefixBoolOperator(exp, "NOT") return newPrefixBoolOperator(exp, "NOT")
} }
// BIT_NOT inverts every bit in integer expression result
func BIT_NOT(expr IntegerExpression) IntegerExpression { func BIT_NOT(expr IntegerExpression) IntegerExpression {
return newPrefixIntegerOperator(expr, "~") return newPrefixIntegerOperator(expr, "~")
} }
//----------- Comparison operators ---------------// //----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression { func EXISTS(subQuery SelectStatement) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS") return newPrefixBoolOperator(subQuery, "EXISTS")
} }
@ -59,12 +61,13 @@ func gtEq(lhs, rhs Expression) BoolExpression {
// --------------- CASE operator -------------------// // --------------- CASE operator -------------------//
type CaseOperatorExpression interface { // CaseOperator is interface for SQL case operator
type CaseOperator interface {
Expression Expression
WHEN(condition Expression) CaseOperatorExpression WHEN(condition Expression) CaseOperator
THEN(then Expression) CaseOperatorExpression THEN(then Expression) CaseOperator
ELSE(els Expression) CaseOperatorExpression ELSE(els Expression) CaseOperator
} }
type caseOperatorImpl struct { type caseOperatorImpl struct {
@ -76,7 +79,8 @@ type caseOperatorImpl struct {
els Expression els Expression
} }
func CASE(expression ...Expression) CaseOperatorExpression { // CASE create CASE operator with optional list of expressions
func CASE(expression ...Expression) CaseOperator {
caseExp := &caseOperatorImpl{} caseExp := &caseOperatorImpl{}
if len(expression) > 0 { if len(expression) > 0 {
@ -88,17 +92,17 @@ func CASE(expression ...Expression) CaseOperatorExpression {
return caseExp return caseExp
} }
func (c *caseOperatorImpl) WHEN(when Expression) CaseOperatorExpression { func (c *caseOperatorImpl) WHEN(when Expression) CaseOperator {
c.when = append(c.when, when) c.when = append(c.when, when)
return c return c
} }
func (c *caseOperatorImpl) THEN(then Expression) CaseOperatorExpression { func (c *caseOperatorImpl) THEN(then Expression) CaseOperator {
c.then = append(c.then, then) c.then = append(c.then, then)
return c return c
} }
func (c *caseOperatorImpl) ELSE(els Expression) CaseOperatorExpression { func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
c.els = els c.els = els
return c return c

View file

@ -5,6 +5,7 @@ import "testing"
func TestOperatorNOT(t *testing.T) { func TestOperatorNOT(t *testing.T) {
notExpression := NOT(Int(2).EQ(Int(1))) notExpression := NOT(Int(2).EQ(Int(1)))
assertClauseSerialize(t, NOT(table1ColBool), "NOT table1.col_bool")
assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1)) assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1))
assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1)) assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1))
assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5)) assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5))

View file

@ -2,7 +2,8 @@ package jet
import "errors" import "errors"
type OrderByClause interface { // OrderByClause
type orderByClause interface {
serializeForOrderBy(statement statementType, out *sqlBuilder) error serializeForOrderBy(statement statementType, out *sqlBuilder) error
} }
@ -13,7 +14,7 @@ type orderByClauseImpl struct {
func (o *orderByClauseImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error { func (o *orderByClauseImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if o.expression == nil { if o.expression == nil {
return errors.New("jet: nil orderBy by clause.") return errors.New("jet: nil orderBy by clause")
} }
if err := o.expression.serializeForOrderBy(statement, out); err != nil { if err := o.expression.serializeForOrderBy(statement, out); err != nil {
@ -29,6 +30,6 @@ func (o *orderByClauseImpl) serializeForOrderBy(statement statementType, out *sq
return nil return nil
} }
func newOrderByClause(expression Expression, ascent bool) OrderByClause { func newOrderByClause(expression Expression, ascent bool) orderByClause {
return &orderByClauseImpl{expression: expression, ascent: ascent} return &orderByClauseImpl{expression: expression, ascent: ascent}
} }

View file

@ -2,12 +2,13 @@ package jet
type projection interface { type projection interface {
serializeForProjection(statement statementType, out *sqlBuilder) error serializeForProjection(statement statementType, out *sqlBuilder) error
from(subQuery ExpressionTable) projection from(subQuery SelectTable) projection
} }
// ProjectionList is a redefined type, so that ProjectionList can be used as a projection.
type ProjectionList []projection type ProjectionList []projection
func (cl ProjectionList) from(subQuery ExpressionTable) projection { func (cl ProjectionList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{} newProjectionList := ProjectionList{}
for _, projection := range cl { for _, projection := range cl {

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
) )
// Select statements lock types
var ( var (
UPDATE = newLock("UPDATE") UPDATE = newLock("UPDATE")
NO_KEY_UPDATE = newLock("NO KEY UPDATE") NO_KEY_UPDATE = newLock("NO KEY UPDATE")
@ -14,6 +15,7 @@ var (
KEY_SHARE = newLock("KEY SHARE") KEY_SHARE = newLock("KEY SHARE")
) )
// SelectStatement is interface for SQL SELECT statements
type SelectStatement interface { type SelectStatement interface {
Statement Statement
Expression Expression
@ -23,7 +25,7 @@ type SelectStatement interface {
WHERE(expression BoolExpression) SelectStatement WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...groupByClause) SelectStatement GROUP_BY(groupByClauses ...groupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement ORDER_BY(orderByClauses ...orderByClause) SelectStatement
LIMIT(limit int64) SelectStatement LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement OFFSET(offset int64) SelectStatement
FOR(lock SelectLock) SelectStatement FOR(lock SelectLock) SelectStatement
@ -35,11 +37,12 @@ type SelectStatement interface {
EXCEPT(rhs SelectStatement) SelectStatement EXCEPT(rhs SelectStatement) SelectStatement
EXCEPT_ALL(rhs SelectStatement) SelectStatement EXCEPT_ALL(rhs SelectStatement) SelectStatement
AsTable(alias string) ExpressionTable AsTable(alias string) SelectTable
projections() []projection projections() []projection
} }
//SELECT creates new SelectStatement with list of projections
func SELECT(projection1 projection, projections ...projection) SelectStatement { func SELECT(projection1 projection, projections ...projection) SelectStatement {
return newSelectStatement(nil, append([]projection{projection1}, projections...)) return newSelectStatement(nil, append([]projection{projection1}, projections...))
} }
@ -55,7 +58,7 @@ type selectStatementImpl struct {
groupBy []groupByClause groupBy []groupByClause
having BoolExpression having BoolExpression
orderBy []OrderByClause orderBy []orderByClause
limit, offset int64 limit, offset int64
lockFor SelectLock lockFor SelectLock
@ -81,8 +84,8 @@ func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
return s.parent return s.parent
} }
func (s *selectStatementImpl) AsTable(alias string) ExpressionTable { func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newExpressionTable(s.parent, alias, s.parent.projections()) return newSelectTable(s.parent, alias)
} }
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
@ -100,7 +103,7 @@ func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement
return s.parent return s.parent
} }
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { func (s *selectStatementImpl) ORDER_BY(clauses ...orderByClause) SelectStatement {
s.orderBy = clauses s.orderBy = clauses
return s.parent return s.parent
} }
@ -189,20 +192,20 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: no column selected for projection") return errors.New("jet: no column selected for projection")
} }
err := out.writeProjections(select_statement, s.projectionList) err := out.writeProjections(selectStatement, s.projectionList)
if err != nil { if err != nil {
return err return err
} }
if s.table != nil { if s.table != nil {
if err := out.writeFrom(select_statement, s.table); err != nil { if err := out.writeFrom(selectStatement, s.table); err != nil {
return err return err
} }
} }
if s.where != nil { if s.where != nil {
err := out.writeWhere(select_statement, s.where) err := out.writeWhere(selectStatement, s.where)
if err != nil { if err != nil {
return nil return nil
@ -210,7 +213,7 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
} }
if s.groupBy != nil && len(s.groupBy) > 0 { if s.groupBy != nil && len(s.groupBy) > 0 {
err := out.writeGroupBy(select_statement, s.groupBy) err := out.writeGroupBy(selectStatement, s.groupBy)
if err != nil { if err != nil {
return err return err
@ -218,7 +221,7 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
} }
if s.having != nil { if s.having != nil {
err := out.writeHaving(select_statement, s.having) err := out.writeHaving(selectStatement, s.having)
if err != nil { if err != nil {
return err return err
@ -226,7 +229,7 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
} }
if s.orderBy != nil { if s.orderBy != nil {
err := out.writeOrderBy(select_statement, s.orderBy) err := out.writeOrderBy(selectStatement, s.orderBy)
if err != nil { if err != nil {
return err return err
@ -236,19 +239,19 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
if s.limit >= 0 { if s.limit >= 0 {
out.newLine() out.newLine()
out.writeString("LIMIT") out.writeString("LIMIT")
out.insertPreparedArgument(s.limit) out.insertParametrizedArgument(s.limit)
} }
if s.offset >= 0 { if s.offset >= 0 {
out.newLine() out.newLine()
out.writeString("OFFSET") out.writeString("OFFSET")
out.insertPreparedArgument(s.offset) out.insertParametrizedArgument(s.offset)
} }
if s.lockFor != nil { if s.lockFor != nil {
out.newLine() out.newLine()
out.writeString("FOR") out.writeString("FOR")
err := s.lockFor.serialize(select_statement, out) err := s.lockFor.serialize(selectStatement, out)
if err != nil { if err != nil {
return err return err
@ -280,20 +283,19 @@ func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) er
return query(s.parent, db, destination) return query(s.parent, db, destination)
} }
func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { func (s *selectStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(s.parent, db, context, destination) return queryContext(context, s.parent, db, destination)
} }
func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s.parent, db) return exec(s.parent, db)
} }
func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { func (s *selectStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(s.parent, db, context) return execContext(context, s.parent, db)
} }
// SelectLock // SelectLock is interface for SELECT statement locks
type SelectLock interface { type SelectLock interface {
clause clause

View file

@ -122,3 +122,91 @@ FROM db.table1
FOR NO KEY UPDATE SKIP LOCKED; FOR NO KEY UPDATE SKIP LOCKED;
`) `)
} }
func TestSelectSets(t *testing.T) {
select1 := SELECT(table1ColBool).FROM(table1)
select2 := SELECT(table2ColBool).FROM(table2)
assertStatement(t, select1.UNION(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.UNION_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
}

63
select_table.go Normal file
View file

@ -0,0 +1,63 @@
package jet
import "errors"
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
ReadableTable
Alias() string
AllColumns() ProjectionList
}
type selectTableImpl struct {
readableTableInterfaceImpl
selectStmt SelectStatement
alias string
projections []projection
}
func newSelectTable(selectStmt SelectStatement, alias string) SelectTable {
expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
expTable.readableTableInterfaceImpl.parent = expTable
for _, projection := range selectStmt.projections() {
newProjection := projection.from(expTable)
expTable.projections = append(expTable.projections, newProjection)
}
return expTable
}
func (s *selectTableImpl) Alias() string {
return s.alias
}
func (s *selectTableImpl) columns() []column {
return nil
}
func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *selectTableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if s == nil {
return errors.New("jet: Expression table is nil. ")
}
err := s.selectStmt.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeIdentifier(s.alias)
return nil
}

View file

@ -4,28 +4,40 @@ import (
"errors" "errors"
) )
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
} }
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
} }
// INTERSECT returns all rows that are in query results.
// It eliminates duplicate rows from its result.
func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...))
} }
// INTERSECT_ALL returns all rows that are in query results.
// It does not eliminates duplicate rows from its result.
func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...))
} }
func EXCEPT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { // EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs.
return newSetStatementImpl(except, false, toSelectList(lhs, rhs, selects...)) // It eliminates duplicate rows from its result.
func EXCEPT(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs))
} }
func EXCEPT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { // EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs.
return newSetStatementImpl(except, true, toSelectList(lhs, rhs, selects...)) // It does not eliminates duplicate rows from its result.
func EXCEPT_ALL(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs))
} }
func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement { func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement {
@ -102,7 +114,7 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
} }
if len(s.selects) < 2 { if len(s.selects) < 2 {
return errors.New("jet: UNION Statement must have at least two SELECT statements.") return errors.New("jet: UNION Statement must have at least two SELECT statements")
} }
out.newLine() out.newLine()
@ -124,7 +136,7 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: select statement is nil") return errors.New("jet: select statement is nil")
} }
err := selectStmt.serialize(set_statement, out) err := selectStmt.serialize(setStatement, out)
if err != nil { if err != nil {
return err return err
@ -136,7 +148,7 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
out.writeString(")") out.writeString(")")
if s.orderBy != nil { if s.orderBy != nil {
err := out.writeOrderBy(set_statement, s.orderBy) err := out.writeOrderBy(setStatement, s.orderBy)
if err != nil { if err != nil {
return err return err
} }
@ -145,13 +157,13 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
if s.limit >= 0 { if s.limit >= 0 {
out.newLine() out.newLine()
out.writeString("LIMIT") out.writeString("LIMIT")
out.insertPreparedArgument(s.limit) out.insertParametrizedArgument(s.limit)
} }
if s.offset >= 0 { if s.offset >= 0 {
out.newLine() out.newLine()
out.writeString("OFFSET") out.writeString("OFFSET")
out.insertPreparedArgument(s.offset) out.insertParametrizedArgument(s.offset)
} }
return nil return nil

View file

@ -6,7 +6,7 @@ import (
) )
func TestUnionTwoSelect(t *testing.T) { func TestUnionTwoSelect(t *testing.T) {
var expectedSql = ` var expectedSQL = `
( (
( (
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
@ -27,8 +27,8 @@ func TestUnionTwoSelect(t *testing.T) {
unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3)) unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3))
assertStatement(t, unionStmt1, expectedSql) assertStatement(t, unionStmt1, expectedSQL)
assertStatement(t, unionStmt2, expectedSql) assertStatement(t, unionStmt2, expectedSQL)
} }
func TestUnionNilSelect(t *testing.T) { func TestUnionNilSelect(t *testing.T) {
@ -49,7 +49,7 @@ func TestUnionThreeSelect1(t *testing.T) {
table3.SELECT(table3Col1), table3.SELECT(table3Col1),
) )
var expectedSql = ` var expectedSQL = `
( (
( (
@ -71,7 +71,7 @@ func TestUnionThreeSelect1(t *testing.T) {
); );
` `
assertStatement(t, unionStmt1, expectedSql) assertStatement(t, unionStmt1, expectedSQL)
} }
func TestUnionThreeSelect2(t *testing.T) { func TestUnionThreeSelect2(t *testing.T) {
@ -82,7 +82,7 @@ func TestUnionThreeSelect2(t *testing.T) {
table3.SELECT(table3Col1), table3.SELECT(table3Col1),
) )
var expectedSql = ` var expectedSQL = `
( (
( (
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
@ -101,7 +101,7 @@ func TestUnionThreeSelect2(t *testing.T) {
); );
` `
assertStatement(t, unionStmt2, expectedSql) assertStatement(t, unionStmt2, expectedSQL)
} }
func TestUnionWithOrderBy(t *testing.T) { func TestUnionWithOrderBy(t *testing.T) {
@ -155,7 +155,7 @@ OFFSET $2;
} }
func TestUnionInUnion(t *testing.T) { func TestUnionInUnion(t *testing.T) {
expectedSql := ` expectedSQL := `
( (
( (
SELECT table2.col3 AS "table2.col3", SELECT table2.col3 AS "table2.col3",
@ -182,7 +182,7 @@ func TestUnionInUnion(t *testing.T) {
UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)), UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)),
) )
assertStatement(t, query, expectedSql) assertStatement(t, query, expectedSQL)
} }
func TestUnionALL(t *testing.T) { func TestUnionALL(t *testing.T) {

View file

@ -8,6 +8,7 @@ import (
"strings" "strings"
) )
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement interface { type Statement interface {
// Sql returns parametrized sql query with list of arguments. // Sql returns parametrized sql query with list of arguments.
// err is returned if statement is not composed correctly // err is returned if statement is not composed correctly
@ -22,12 +23,12 @@ type Statement interface {
Query(db execution.DB, destination interface{}) error Query(db execution.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination. // QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be of arbitrary structure // Destination can be of arbitrary structure
QueryContext(db execution.DB, context context.Context, destination interface{}) error QueryContext(context context.Context, db execution.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows. //Exec executes statement over db connection without returning any rows.
Exec(db execution.DB) (sql.Result, error) Exec(db execution.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows. //Exec executes statement with context over db connection without returning any rows.
ExecContext(db execution.DB, context context.Context) (sql.Result, error) ExecContext(context context.Context, db execution.DB) (sql.Result, error)
} }
func debugSql(statement Statement) (string, error) { func debugSql(statement Statement) (string, error) {
@ -37,14 +38,14 @@ func debugSql(statement Statement) (string, error) {
return "", err return "", err
} }
debugSqlQuery := sqlQuery debugSQLQuery := sqlQuery
for i, arg := range args { for i, arg := range args {
argPlaceholder := "$" + strconv.Itoa(i+1) argPlaceholder := "$" + strconv.Itoa(i+1)
debugSqlQuery = strings.Replace(debugSqlQuery, argPlaceholder, ArgToString(arg), 1) debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
} }
return debugSqlQuery, nil return debugSQLQuery, nil
} }
func query(statement Statement, db execution.DB, destination interface{}) error { func query(statement Statement, db execution.DB, destination interface{}) error {
@ -54,17 +55,17 @@ func query(statement Statement, db execution.DB, destination interface{}) error
return err return err
} }
return execution.Query(db, context.Background(), query, args, destination) return execution.Query(context.Background(), db, query, args, destination)
} }
func queryContext(statement Statement, db execution.DB, context context.Context, destination interface{}) error { func queryContext(context context.Context, statement Statement, db execution.DB, destination interface{}) error {
query, args, err := statement.Sql() query, args, err := statement.Sql()
if err != nil { if err != nil {
return err return err
} }
return execution.Query(db, context, query, args, destination) return execution.Query(context, db, query, args, destination)
} }
func exec(statement Statement, db execution.DB) (res sql.Result, err error) { func exec(statement Statement, db execution.DB) (res sql.Result, err error) {
@ -77,7 +78,7 @@ func exec(statement Statement, db execution.DB) (res sql.Result, err error) {
return db.Exec(query, args...) return db.Exec(query, args...)
} }
func execContext(statement Statement, db execution.DB, context context.Context) (res sql.Result, err error) { func execContext(context context.Context, statement Statement, db execution.DB) (res sql.Result, err error) {
query, args, err := statement.Sql() query, args, err := statement.Sql()
if err != nil { if err != nil {

View file

@ -1,5 +1,6 @@
package jet package jet
// StringExpression interface
type StringExpression interface { type StringExpression interface {
Expression Expression
@ -108,6 +109,9 @@ func newStringExpressionWrap(expression Expression) StringExpression {
return &stringExpressionWrap return &stringExpressionWrap
} }
// StringExp is string expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as string expression.
// Does not add sql cast to generated sql builder output.
func StringExp(expression Expression) StringExpression { func StringExp(expression Expression) StringExpression {
return newStringExpressionWrap(expression) return newStringExpressionWrap(expression)
} }

View file

@ -17,6 +17,16 @@ func TestStringNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_EQ(String("JOHN")), "(table3.col2 != $1)", "JOHN") assertClauseSerialize(t, table3StrCol.NOT_EQ(String("JOHN")), "(table3.col2 != $1)", "JOHN")
} }
func TestStringExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table3StrCol.IS_DISTINCT_FROM(table2ColStr), "(table3.col2 IS DISTINCT FROM table2.col_str)")
assertClauseSerialize(t, table3StrCol.IS_DISTINCT_FROM(String("JOHN")), "(table3.col2 IS DISTINCT FROM $1)", "JOHN")
}
func TestStringExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table3StrCol.IS_NOT_DISTINCT_FROM(table2ColStr), "(table3.col2 IS NOT DISTINCT FROM table2.col_str)")
assertClauseSerialize(t, table3StrCol.IS_NOT_DISTINCT_FROM(String("JOHN")), "(table3.col2 IS NOT DISTINCT FROM $1)", "JOHN")
}
func TestStringGT(t *testing.T) { func TestStringGT(t *testing.T) {
exp := table3StrCol.GT(table2ColStr) exp := table3StrCol.GT(table2ColStr)
assertClauseSerialize(t, exp, "(table3.col2 > table2.col_str)") assertClauseSerialize(t, exp, "(table3.col2 > table2.col_str)")

View file

@ -2,6 +2,7 @@ package jet
import ( import (
"errors" "errors"
"github.com/go-jet/jet/internal/utils"
) )
type table interface { type table interface {
@ -36,18 +37,21 @@ type writableTable interface {
LOCK() LockStatement LOCK() LockStatement
} }
// ReadableTable interface
type ReadableTable interface { type ReadableTable interface {
table table
readableTable readableTable
clause clause
} }
// WritableTable interface
type WritableTable interface { type WritableTable interface {
table table
writableTable writableTable
clause clause
} }
// Table interface
type Table interface { type Table interface {
table table
readableTable readableTable
@ -110,6 +114,7 @@ func (w *writableTableInterfaceImpl) LOCK() LockStatement {
return LOCK(w.parent) return LOCK(w.parent)
} }
// NewTable creates new table with schema name, table name and list of columns
func NewTable(schemaName, name string, columns ...Column) Table { func NewTable(schemaName, name string, columns ...Column) Table {
t := &tableImpl{ t := &tableImpl{
@ -196,20 +201,20 @@ type joinTable struct {
lhs ReadableTable lhs ReadableTable
rhs ReadableTable rhs ReadableTable
join_type joinType joinType joinType
onCondition BoolExpression onCondition BoolExpression
} }
func newJoinTable( func newJoinTable(
lhs ReadableTable, lhs ReadableTable,
rhs ReadableTable, rhs ReadableTable,
join_type joinType, joinType joinType,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
joinTable := &joinTable{ joinTable := &joinTable{
lhs: lhs, lhs: lhs,
rhs: rhs, rhs: rhs,
join_type: join_type, joinType: joinType,
onCondition: onCondition, onCondition: onCondition,
} }
@ -235,7 +240,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
return errors.New("jet: Join table is nil. ") return errors.New("jet: Join table is nil. ")
} }
if isNil(t.lhs) { if utils.IsNil(t.lhs) {
return errors.New("jet: left hand side of join operation is nil table") return errors.New("jet: left hand side of join operation is nil table")
} }
@ -245,7 +250,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
out.newLine() out.newLine()
switch t.join_type { switch t.joinType {
case innerJoin: case innerJoin:
out.writeString("INNER JOIN") out.writeString("INNER JOIN")
case leftJoin: case leftJoin:
@ -258,7 +263,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
out.writeString("CROSS JOIN") out.writeString("CROSS JOIN")
} }
if isNil(t.rhs) { if utils.IsNil(t.rhs) {
return errors.New("jet: right hand side of join operation is nil table") return errors.New("jet: right hand side of join operation is nil table")
} }
@ -266,7 +271,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
return return
} }
if t.onCondition == nil && t.join_type != crossJoin { if t.onCondition == nil && t.joinType != crossJoin {
return errors.New("jet: join condition is nil") return errors.New("jet: join condition is nil")
} }

View file

@ -127,7 +127,10 @@ func TestStringOperators(t *testing.T) {
LOWER(AllTypes.CharacterVaryingPtr), LOWER(AllTypes.CharacterVaryingPtr),
UPPER(AllTypes.Character), UPPER(AllTypes.Character),
BTRIM(AllTypes.CharacterVarying), BTRIM(AllTypes.CharacterVarying),
BTRIM(AllTypes.CharacterVarying, String("AA")),
LTRIM(AllTypes.CharacterVarying),
LTRIM(AllTypes.CharacterVarying, String("A")), LTRIM(AllTypes.CharacterVarying, String("A")),
RTRIM(AllTypes.CharacterVarying),
RTRIM(AllTypes.CharacterVarying, String("B")), RTRIM(AllTypes.CharacterVarying, String("B")),
CHR(Int(65)), CHR(Int(65)),
//CONCAT(String("string1"), Int(1), Float(11.12)), //CONCAT(String("string1"), Int(1), Float(11.12)),
@ -143,13 +146,16 @@ func TestStringOperators(t *testing.T) {
RIGHT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)),
LENGTH(String("jose")), LENGTH(String("jose")),
LENGTH(String("jose"), String("UTF8")), LENGTH(String("jose"), String("UTF8")),
LPAD(String("Hi"), Int(5)),
LPAD(String("Hi"), Int(5), String("xy")), LPAD(String("Hi"), Int(5), String("xy")),
RPAD(String("Hi"), Int(5)),
RPAD(String("Hi"), Int(5), String("xy")), RPAD(String("Hi"), Int(5), String("xy")),
MD5(AllTypes.CharacterVarying), MD5(AllTypes.CharacterVarying),
REPEAT(AllTypes.Text, Int(33)), REPEAT(AllTypes.Text, Int(33)),
REPLACE(AllTypes.Character, String("BA"), String("AB")), REPLACE(AllTypes.Character, String("BA"), String("AB")),
REVERSE(AllTypes.CharacterVarying), REVERSE(AllTypes.CharacterVarying),
STRPOS(AllTypes.Text, String("A")), STRPOS(AllTypes.Text, String("A")),
SUBSTR(AllTypes.CharacterPtr, Int(3)),
SUBSTR(AllTypes.CharacterPtr, Int(3), Int(2)), SUBSTR(AllTypes.CharacterPtr, Int(3), Int(2)),
TO_HEX(AllTypes.IntegerPtr), TO_HEX(AllTypes.IntegerPtr),
) )
@ -229,7 +235,7 @@ func TestFloatOperators(t *testing.T) {
CEIL(AllTypes.Real), CEIL(AllTypes.Real),
FLOOR(AllTypes.Real), FLOOR(AllTypes.Real),
ROUND(AllTypes.Decimal), ROUND(AllTypes.Decimal),
ROUND(AllTypes.Decimal, Int(3)).AS("round"), ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round"),
SIGN(AllTypes.Real), SIGN(AllTypes.Real),
TRUNC(AllTypes.Decimal), TRUNC(AllTypes.Decimal),
TRUNC(AllTypes.Decimal, Int(1)), TRUNC(AllTypes.Decimal, Int(1)),
@ -361,7 +367,7 @@ func TestSubQueryColumnReference(t *testing.T) {
args []interface{} args []interface{}
} }
subQueries := map[ExpressionTable]expected{} subQueries := map[SelectTable]expected{}
selectSubQuery := AllTypes.SELECT( selectSubQuery := AllTypes.SELECT(
AllTypes.Boolean, AllTypes.Boolean,
@ -378,7 +384,7 @@ func TestSubQueryColumnReference(t *testing.T) {
LIMIT(2). LIMIT(2).
AsTable("subQuery") AsTable("subQuery")
var selectExpectedSql = ` ( var selectexpectedSQL = ` (
SELECT all_types.boolean AS "all_types.boolean", SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer", all_types.integer AS "all_types.integer",
all_types.real AS "all_types.real", all_types.real AS "all_types.real",
@ -424,7 +430,7 @@ func TestSubQueryColumnReference(t *testing.T) {
). ).
AsTable("subQuery") AsTable("subQuery")
unionExpectedSql := ` unionexpectedSQL := `
( (
( (
SELECT all_types.boolean AS "all_types.boolean", SELECT all_types.boolean AS "all_types.boolean",
@ -458,8 +464,8 @@ func TestSubQueryColumnReference(t *testing.T) {
) )
) AS "subQuery"` ) AS "subQuery"`
subQueries[selectSubQuery] = expected{sql: selectExpectedSql, args: []interface{}{int64(2)}} subQueries[selectSubQuery] = expected{sql: selectexpectedSQL, args: []interface{}{int64(2)}}
subQueries[unionSubQuery] = expected{sql: unionExpectedSql, args: []interface{}{int64(1), int64(1), int64(1)}} subQueries[unionSubQuery] = expected{sql: unionexpectedSQL, args: []interface{}{int64(1), int64(1), int64(1)}}
for subQuery, expected := range subQueries { for subQuery, expected := range subQueries {
boolColumn := AllTypes.Boolean.From(subQuery) boolColumn := AllTypes.Boolean.From(subQuery)
@ -487,7 +493,7 @@ func TestSubQueryColumnReference(t *testing.T) {
). ).
FROM(subQuery) FROM(subQuery)
var expectedSql = ` var expectedSQL = `
SELECT "subQuery"."all_types.boolean" AS "all_types.boolean", SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."all_types.integer" AS "all_types.integer", "subQuery"."all_types.integer" AS "all_types.integer",
"subQuery"."all_types.real" AS "all_types.real", "subQuery"."all_types.real" AS "all_types.real",
@ -500,7 +506,7 @@ SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."aliasedColumn" AS "aliasedColumn" "subQuery"."aliasedColumn" AS "aliasedColumn"
FROM` FROM`
assertStatementSql(t, stmt1, expectedSql+expected.sql+";\n", expected.args...) assertStatementSql(t, stmt1, expectedSQL+expected.sql+";\n", expected.args...)
dest1 := []model.AllTypes{} dest1 := []model.AllTypes{}
err := stmt1.Query(db, &dest1) err := stmt1.Query(db, &dest1)
@ -523,7 +529,7 @@ FROM`
//fmt.Println(stmt2.DebugSql()) //fmt.Println(stmt2.DebugSql())
assertStatementSql(t, stmt2, expectedSql+expected.sql+";\n", expected.args...) assertStatementSql(t, stmt2, expectedSQL+expected.sql+";\n", expected.args...)
dest2 := []model.AllTypes{} dest2 := []model.AllTypes{}
err = stmt2.Query(db, &dest2) err = stmt2.Query(db, &dest2)

View file

@ -1,6 +1,7 @@
package tests package tests
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -9,6 +10,7 @@ import (
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table"
"gotest.tools/assert" "gotest.tools/assert"
"io/ioutil" "io/ioutil"
"runtime"
"testing" "testing"
"time" "time"
) )
@ -104,7 +106,7 @@ func TestJoinEverything(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 275) assert.Equal(t, len(dest), 275)
assertJsonFile(t, "./testdata/joined_everything.json", dest) assertJSONFile(t, "./testdata/joined_everything.json", dest)
} }
func TestSelfJoin(t *testing.T) { func TestSelfJoin(t *testing.T) {
@ -144,7 +146,7 @@ ORDER BY "Employee"."EmployeeId";
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 8) assert.Equal(t, len(dest), 8)
assertJson(t, dest[0:2], ` assertJSON(t, dest[0:2], `
[ [
{ {
"EmployeeId": 1, "EmployeeId": 1,
@ -244,23 +246,21 @@ ORDER BY "Album.AlbumId";
assert.DeepEqual(t, dest[1], album2) assert.DeepEqual(t, dest[1], album2)
} }
//func TestQueryWithContext(t *testing.T) { func TestQueryWithContext(t *testing.T) {
//
// ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
// defer cancel() defer cancel()
//
// dest := []model.Album{} dest := []model.Album{}
//
// err := Album. err := Album.
// CROSS_JOIN(Track). CROSS_JOIN(Track).
// CROSS_JOIN(InvoiceLine). CROSS_JOIN(InvoiceLine).
// SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns). SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
// QueryContext(db, ctx, &dest) QueryContext(ctx, db, &dest)
//
// spew.Dump(dest) assert.Error(t, err, "context deadline exceeded")
// }
// assert.Error(t, err, "context deadline exceeded")
//}
func TestExecWithContext(t *testing.T) { func TestExecWithContext(t *testing.T) {
@ -271,7 +271,7 @@ func TestExecWithContext(t *testing.T) {
CROSS_JOIN(Track). CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine). CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns). SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
ExecContext(db, ctx) ExecContext(ctx, db)
assert.Error(t, err, "pq: canceling statement due to user request") assert.Error(t, err, "pq: canceling statement due to user request")
} }
@ -283,7 +283,7 @@ func TestSubQueriesForQuotedNames(t *testing.T) {
LIMIT(10). LIMIT(10).
AsTable("first10Artist") AsTable("first10Artist")
artistId := Artist.ArtistId.From(first10Artist) artistID := Artist.ArtistId.From(first10Artist)
first10Albums := Album. first10Albums := Album.
SELECT(Album.AllColumns). SELECT(Album.AllColumns).
@ -291,12 +291,12 @@ func TestSubQueriesForQuotedNames(t *testing.T) {
LIMIT(10). LIMIT(10).
AsTable("first10Albums") AsTable("first10Albums")
albumArtistId := Album.ArtistId.From(first10Albums) albumArtistID := Album.ArtistId.From(first10Albums)
stmt := first10Artist. stmt := first10Artist.
INNER_JOIN(first10Albums, artistId.EQ(albumArtistId)). INNER_JOIN(first10Albums, artistID.EQ(albumArtistID)).
SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()). SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()).
ORDER_BY(artistId) ORDER_BY(artistID)
assertStatementSql(t, stmt, ` assertStatementSql(t, stmt, `
SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId", SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId",
@ -335,22 +335,26 @@ ORDER BY "first10Artist"."Artist.ArtistId";
//spew.Dump(dest) //spew.Dump(dest)
} }
func assertJson(t *testing.T, data interface{}, expectedJson string) { func assertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJson) assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
} }
func assertJsonFile(t *testing.T, jsonFilePath string, data interface{}) { func assertJSONFile(t *testing.T, jsonFilePath string, data interface{}) {
fileJsonData, err := ioutil.ReadFile(jsonFilePath) fileJSONData, err := ioutil.ReadFile(jsonFilePath)
assert.NilError(t, err) assert.NilError(t, err)
if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
}
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err) assert.NilError(t, err)
assert.Assert(t, string(fileJsonData) == string(jsonData)) assert.Assert(t, string(fileJSONData) == string(jsonData))
//assert.Equal(t, string(fileJsonData), string(jsonData)) //assert.Equal(t, string(fileJSONData), string(jsonData))
} }
func jsonPrint(v interface{}) { func jsonPrint(v interface{}) {

View file

@ -2,6 +2,7 @@ package dbconfig
import "fmt" import "fmt"
// test database connection parameters
const ( const (
Host = "localhost" Host = "localhost"
Port = 5432 Port = 5432
@ -10,4 +11,5 @@ const (
DBName = "jetdb" DBName = "jetdb"
) )
// ConnectString is PostgreSQL connection string
var ConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) var ConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)

View file

@ -1,17 +1,19 @@
package tests package tests
import ( import (
"context"
. "github.com/go-jet/jet" . "github.com/go-jet/jet"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
"time"
) )
func TestDeleteWithWhere(t *testing.T) { func TestDeleteWithWhere(t *testing.T) {
initForDeleteTest(t) initForDeleteTest(t)
var expectedSql = ` var expectedSQL = `
DELETE FROM test_sample.link DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook'); WHERE link.name IN ('Gmail', 'Outlook');
` `
@ -19,14 +21,14 @@ WHERE link.name IN ('Gmail', 'Outlook');
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
assertStatementSql(t, deleteStmt, expectedSql, "Gmail", "Outlook") assertStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
assertExec(t, deleteStmt, 2) assertExec(t, deleteStmt, 2)
} }
func TestDeleteWithWhereAndReturning(t *testing.T) { func TestDeleteWithWhereAndReturning(t *testing.T) {
initForDeleteTest(t) initForDeleteTest(t)
var expectedSql = ` var expectedSQL = `
DELETE FROM test_sample.link DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook') WHERE link.name IN ('Gmail', 'Outlook')
RETURNING link.id AS "link.id", RETURNING link.id AS "link.id",
@ -39,7 +41,7 @@ RETURNING link.id AS "link.id",
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
assertStatementSql(t, deleteStmt, expectedSql, "Gmail", "Outlook") assertStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
dest := []model.Link{} dest := []model.Link{}
@ -60,3 +62,38 @@ func initForDeleteTest(t *testing.T) {
assertExec(t, stmt, 2) assertExec(t, stmt, 2)
} }
func TestDeleteQueryContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link.
DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := deleteStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestDeleteExecContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link.
DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := deleteStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
_ "github.com/lib/pq"
"io/ioutil" "io/ioutil"
) )
@ -38,7 +39,7 @@ func main() {
err = postgres.Generate("./.gentestdata", postgres.DBConnection{ err = postgres.Generate("./.gentestdata", postgres.DBConnection{
Host: dbconfig.Host, Host: dbconfig.Host,
Port: "5432", Port: 5432,
User: dbconfig.User, User: dbconfig.User,
Password: dbconfig.Password, Password: dbconfig.Password,
DBName: dbconfig.DBName, DBName: dbconfig.DBName,

View file

@ -1,17 +1,19 @@
package tests package tests
import ( import (
"context"
. "github.com/go-jet/jet" . "github.com/go-jet/jet"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
"time"
) )
func TestInsertValues(t *testing.T) { func TestInsertValues(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
var expectedSql = ` var expectedSQL = `
INSERT INTO test_sample.link (id, url, name, description) VALUES INSERT INTO test_sample.link (id, url, name, description) VALUES
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT), (101, 'http://www.google.com', 'Google', DEFAULT),
@ -28,7 +30,7 @@ RETURNING link.id AS "link.id",
VALUES(102, "http://www.yahoo.com", "Yahoo", nil). VALUES(102, "http://www.yahoo.com", "Yahoo", nil).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
assertStatementSql(t, insertQuery, expectedSql, assertStatementSql(t, insertQuery, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
101, "http://www.google.com", "Google", 101, "http://www.google.com", "Google",
102, "http://www.yahoo.com", "Yahoo", nil) 102, "http://www.yahoo.com", "Yahoo", nil)
@ -74,7 +76,7 @@ RETURNING link.id AS "link.id",
func TestInsertEmptyColumnList(t *testing.T) { func TestInsertEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
expectedSql := ` expectedSQL := `
INSERT INTO test_sample.link VALUES INSERT INTO test_sample.link VALUES
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
` `
@ -82,7 +84,7 @@ INSERT INTO test_sample.link VALUES
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
assertStatementSql(t, stmt, expectedSql, assertStatementSql(t, stmt, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
assertExec(t, stmt, 1) assertExec(t, stmt, 1)
@ -90,7 +92,7 @@ INSERT INTO test_sample.link VALUES
func TestInsertModelObject(t *testing.T) { func TestInsertModelObject(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
var expectedSql = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name) VALUES INSERT INTO test_sample.link (url, name) VALUES
('http://www.duckduckgo.com', 'Duck Duck go'); ('http://www.duckduckgo.com', 'Duck Duck go');
` `
@ -104,19 +106,35 @@ INSERT INTO test_sample.link (url, name) VALUES
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
MODEL(linkData) MODEL(linkData)
assertStatementSql(t, query, expectedSql, "http://www.duckduckgo.com", "Duck Duck go") assertStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go")
result, err := query.Exec(db) assertExec(t, query, 1)
}
assert.NilError(t, err) func TestInsertModelObjectEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link VALUES
(1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
rowsAffected, err := result.RowsAffected() linkData := model.Link{
ID: 1000,
URL: "http://www.duckduckgo.com",
Name: "Duck Duck go",
}
assert.Equal(t, rowsAffected, int64(1)) query := Link.
INSERT().
MODEL(linkData)
assertStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
assertExec(t, query, 1)
} }
func TestInsertModelsObject(t *testing.T) { func TestInsertModelsObject(t *testing.T) {
expectedSql := ` expectedSQL := `
INSERT INTO test_sample.link (url, name) VALUES INSERT INTO test_sample.link (url, name) VALUES
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'), ('http://www.google.com', 'Google'),
@ -142,7 +160,7 @@ INSERT INTO test_sample.link (url, name) VALUES
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
MODELS([]model.Link{tutorial, google, yahoo}) MODELS([]model.Link{tutorial, google, yahoo})
assertStatementSql(t, stmt, expectedSql, assertStatementSql(t, stmt, expectedSQL,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", "http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo") "http://www.yahoo.com", "Yahoo")
@ -151,7 +169,7 @@ INSERT INTO test_sample.link (url, name) VALUES
} }
func TestInsertUsingMutableColumns(t *testing.T) { func TestInsertUsingMutableColumns(t *testing.T) {
var expectedSql = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name, description) VALUES INSERT INTO test_sample.link (url, name, description) VALUES
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL), ('http://www.google.com', 'Google', NULL),
@ -175,7 +193,7 @@ INSERT INTO test_sample.link (url, name, description) VALUES
MODEL(google). MODEL(google).
MODELS([]model.Link{google, yahoo}) MODELS([]model.Link{google, yahoo})
assertStatementSql(t, stmt, expectedSql, assertStatementSql(t, stmt, expectedSQL,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
@ -190,7 +208,7 @@ func TestInsertQuery(t *testing.T) {
Exec(db) Exec(db)
assert.NilError(t, err) assert.NilError(t, err)
var expectedSql = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name) ( INSERT INTO test_sample.link (url, name) (
SELECT link.url AS "link.url", SELECT link.url AS "link.url",
link.name AS "link.name" link.name AS "link.name"
@ -212,7 +230,7 @@ RETURNING link.id AS "link.id",
). ).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
assertStatementSql(t, query, expectedSql, int64(0)) assertStatementSql(t, query, expectedSQL, int64(0))
dest := []model.Link{} dest := []model.Link{}
@ -229,3 +247,37 @@ RETURNING link.id AS "link.id",
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(youtubeLinks), 2) assert.Equal(t, len(youtubeLinks), 2)
} }
func TestInsertWithQueryContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT().
VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
RETURNING(Link.AllColumns)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := stmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestInsertWithExecContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := stmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}

73
tests/lock_test.go Normal file
View file

@ -0,0 +1,73 @@
package tests
import (
"context"
"gotest.tools/assert"
"testing"
"time"
. "github.com/go-jet/jet"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table"
)
func TestLockTable(t *testing.T) {
expectedSQL := `
LOCK TABLE dvds.address IN`
var testData = []TableLockMode{
LOCK_ACCESS_SHARE,
LOCK_ROW_SHARE,
LOCK_ROW_EXCLUSIVE,
LOCK_SHARE_UPDATE_EXCLUSIVE,
LOCK_SHARE,
LOCK_SHARE_ROW_EXCLUSIVE,
LOCK_EXCLUSIVE,
LOCK_ACCESS_EXCLUSIVE,
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode)
assertStatementSql(t, query, expectedSQL+" "+string(lockMode)+" MODE;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode).NOWAIT()
assertStatementSql(t, query, expectedSQL+" "+string(lockMode)+" MODE NOWAIT;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
}
func TestLockExecContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
tx, _ := db.Begin()
defer tx.Rollback()
_, err := Address.LOCK().IN(LOCK_ACCESS_SHARE).ExecContext(ctx, tx)
assert.Error(t, err, "context deadline exceeded")
}

View file

@ -1,13 +1,17 @@
package tests package tests
import ( import (
"bytes"
"database/sql" "database/sql"
"github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/pkg/profile" "github.com/pkg/profile"
"gotest.tools/assert" "gotest.tools/assert"
"io/ioutil"
"os" "os"
"os/exec"
"reflect" "reflect"
"testing" "testing"
) )
@ -29,7 +33,7 @@ func TestMain(m *testing.M) {
os.Exit(ret) os.Exit(ret)
} }
func TestGenerateModel(t *testing.T) { func TestGeneratedModel(t *testing.T) {
actor := model.Actor{} actor := model.Actor{}
@ -58,3 +62,189 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8") assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
} }
const genTestDir2 = "./.gentestdata2"
func TestCmdGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir2)
assert.NilError(t, err)
cmd := exec.Command("jet", "-dbname=jetdb", "-host=localhost", "-port=5432",
"-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2)
err = cmd.Run()
assert.NilError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genTestDir2)
assert.NilError(t, err)
}
func TestGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir2)
assert.NilError(t, err)
err = postgres.Generate(genTestDir2, postgres.DBConnection{
Host: dbconfig.Host,
Port: dbconfig.Port,
User: dbconfig.User,
Password: dbconfig.Password,
SslMode: "disable",
Params: "",
DBName: dbconfig.DBName,
SchemaName: "dvds",
})
assert.NilError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genTestDir2)
assert.NilError(t, err)
}
func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table")
assert.NilError(t, err)
assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
// Enums SQL Builder files
enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum")
assert.NilError(t, err)
assertFileNameEqual(t, enumFiles, "mpaa_rating.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files
modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model")
assert.NilError(t, err)
assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile)
}
func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
}
var mpaaRatingEnumFile = `
package enum
import "github.com/go-jet/jet"
var MpaaRating = &struct {
G jet.StringExpression
Pg jet.StringExpression
Pg13 jet.StringExpression
R jet.StringExpression
Nc17 jet.StringExpression
}{
G: jet.NewEnumValue("G"),
Pg: jet.NewEnumValue("PG"),
Pg13: jet.NewEnumValue("PG-13"),
R: jet.NewEnumValue("R"),
Nc17: jet.NewEnumValue("NC-17"),
}
`
var actorSQLBuilderFile = `
package table
import (
"github.com/go-jet/jet"
)
var Actor = newActorTable()
type ActorTable struct {
jet.Table
//Columns
ActorID jet.ColumnInteger
FirstName jet.ColumnString
LastName jet.ColumnString
LastUpdate jet.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
}
// creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) *ActorTable {
aliasTable := newActorTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorTable() *ActorTable {
var (
ActorIDColumn = jet.IntegerColumn("actor_id")
FirstNameColumn = jet.StringColumn("first_name")
LastNameColumn = jet.StringColumn("last_name")
LastUpdateColumn = jet.TimestampColumn("last_update")
)
return &ActorTable{
Table: jet.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn},
}
}
`
var actorModelFile = `
package model
import (
"time"
)
type Actor struct {
ActorID int32 ` + "`sql:\"primary_key\"`" + `
FirstName string
LastName string
LastUpdate time.Time
}
`

View file

@ -61,5 +61,5 @@ func TestNorthwindJoinEverything(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
//jsonSave("./testdata/northwind-all.json", dest) //jsonSave("./testdata/northwind-all.json", dest)
assertJsonFile(t, "./testdata/northwind-all.json", dest) assertJSONFile(t, "./testdata/northwind-all.json", dest)
} }

View file

@ -46,7 +46,7 @@ FROM test_sample.person;
err := query.Query(db, &result) err := query.Query(db, &result)
assert.NilError(t, err) assert.NilError(t, err)
assertJson(t, result, ` assertJSON(t, result, `
[ [
{ {
"PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6", "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6",
@ -72,7 +72,7 @@ FROM test_sample.person;
func TestSelecSelfJoin1(t *testing.T) { func TestSelecSelfJoin1(t *testing.T) {
var expectedSql = ` var expectedSQL = `
SELECT employee.employee_id AS "employee.employee_id", SELECT employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name", employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name", employee.last_name AS "employee.last_name",
@ -97,7 +97,7 @@ ORDER BY employee.employee_id;
). ).
ORDER_BY(Employee.EmployeeID) ORDER_BY(Employee.EmployeeID)
assertStatementSql(t, query, expectedSql) assertStatementSql(t, query, expectedSQL)
type Manager model.Employee type Manager model.Employee

View file

@ -19,7 +19,7 @@ func TestScanToInvalidDestination(t *testing.T) {
t.Run("nil dest", func(t *testing.T) { t.Run("nil dest", func(t *testing.T) {
err := query.Query(db, nil) err := query.Query(db, nil)
assert.Error(t, err, "jet: Destination is nil.") assert.Error(t, err, "jet: Destination is nil")
}) })
t.Run("struct dest", func(t *testing.T) { t.Run("struct dest", func(t *testing.T) {

View file

@ -10,7 +10,7 @@ import (
) )
func TestSelect_ScanToStruct(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT DISTINCT actor.actor_id AS "actor.actor_id", SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name", actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name", actor.last_name AS "actor.last_name",
@ -24,7 +24,7 @@ WHERE actor.actor_id = 1;
DISTINCT(). DISTINCT().
WHERE(Actor.ActorID.EQ(Int(1))) WHERE(Actor.ActorID.EQ(Int(1)))
assertStatementSql(t, query, expectedSql, int64(1)) assertStatementSql(t, query, expectedSQL, int64(1))
actor := model.Actor{} actor := model.Actor{}
err := query.Query(db, &actor) err := query.Query(db, &actor)
@ -42,7 +42,7 @@ WHERE actor.actor_id = 1;
} }
func TestClassicSelect(t *testing.T) { func TestClassicSelect(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id", SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id", payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id", payment.staff_id AS "payment.staff_id",
@ -74,7 +74,7 @@ LIMIT 30;
ORDER_BY(Payment.PaymentID.ASC()). ORDER_BY(Payment.PaymentID.ASC()).
LIMIT(30) LIMIT(30)
assertStatementSql(t, query, expectedSql, int64(30)) assertStatementSql(t, query, expectedSQL, int64(30))
dest := []model.Payment{} dest := []model.Payment{}
@ -85,7 +85,7 @@ LIMIT 30;
} }
func TestSelect_ScanToSlice(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id", SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id", customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name", customer.first_name AS "customer.first_name",
@ -103,7 +103,7 @@ ORDER BY customer.customer_id ASC;
query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC()) query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC())
assertStatementSql(t, query, expectedSql) assertStatementSql(t, query, expectedSQL)
err := query.Query(db, &customers) err := query.Query(db, &customers)
assert.NilError(t, err) assert.NilError(t, err)
@ -116,7 +116,7 @@ ORDER BY customer.customer_id ASC;
} }
func TestSelectAndUnionInProjection(t *testing.T) { func TestSelectAndUnionInProjection(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id", SELECT payment.payment_id AS "payment.payment_id",
( (
SELECT customer.customer_id AS "customer.customer_id" SELECT customer.customer_id AS "customer.customer_id"
@ -156,12 +156,12 @@ LIMIT 12;
). ).
LIMIT(12) LIMIT(12)
assertStatementSql(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) assertStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
} }
func TestJoinQueryStruct(t *testing.T) { func TestJoinQueryStruct(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT film_actor.actor_id AS "film_actor.actor_id", SELECT film_actor.actor_id AS "film_actor.actor_id",
film_actor.film_id AS "film_actor.film_id", film_actor.film_id AS "film_actor.film_id",
film_actor.last_update AS "film_actor.last_update", film_actor.last_update AS "film_actor.last_update",
@ -224,7 +224,7 @@ LIMIT 1000;
ORDER_BY(Film.FilmID.ASC()). ORDER_BY(Film.FilmID.ASC()).
LIMIT(1000) LIMIT(1000)
assertStatementSql(t, query, expectedSql, int64(1000)) assertStatementSql(t, query, expectedSQL, int64(1000))
var languageActorFilm []struct { var languageActorFilm []struct {
model.Language model.Language
@ -253,7 +253,7 @@ LIMIT 1000;
} }
func TestJoinQuerySlice(t *testing.T) { func TestJoinQuerySlice(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT language.language_id AS "language.language_id", SELECT language.language_id AS "language.language_id",
language.name AS "language.name", language.name AS "language.name",
language.last_update AS "language.last_update", language.last_update AS "language.last_update",
@ -290,7 +290,7 @@ LIMIT 15;
WHERE(Film.Rating.EQ(enum.MpaaRating.Nc17)). WHERE(Film.Rating.EQ(enum.MpaaRating.Nc17)).
LIMIT(15) LIMIT(15)
assertStatementSql(t, query, expectedSql, int64(15)) assertStatementSql(t, query, expectedSQL, int64(15))
err := query.Query(db, &filmsPerLanguage) err := query.Query(db, &filmsPerLanguage)
@ -532,7 +532,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id;
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assertJson(t, dest, ` assertJSON(t, dest, `
[ [
{ {
"CityID": 312, "CityID": 312,
@ -657,7 +657,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
} }
func TestSelectFullJoin(t *testing.T) { func TestSelectFullJoin(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id", SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id", customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name", customer.first_name AS "customer.first_name",
@ -685,7 +685,7 @@ ORDER BY customer.customer_id ASC;
SELECT(Customer.AllColumns, Address.AllColumns). SELECT(Customer.AllColumns, Address.AllColumns).
ORDER_BY(Customer.CustomerID.ASC()) ORDER_BY(Customer.CustomerID.ASC())
assertStatementSql(t, query, expectedSql) assertStatementSql(t, query, expectedSQL)
allCustomersAndAddress := []struct { allCustomersAndAddress := []struct {
Address *model.Address Address *model.Address
@ -708,7 +708,7 @@ ORDER BY customer.customer_id ASC;
} }
func TestSelectFullCrossJoin(t *testing.T) { func TestSelectFullCrossJoin(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id", SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id", customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name", customer.first_name AS "customer.first_name",
@ -738,7 +738,7 @@ LIMIT 1000;
ORDER_BY(Customer.CustomerID.ASC()). ORDER_BY(Customer.CustomerID.ASC()).
LIMIT(1000) LIMIT(1000)
assertStatementSql(t, query, expectedSql, int64(1000)) assertStatementSql(t, query, expectedSQL, int64(1000))
var customerAddresCrosJoined []struct { var customerAddresCrosJoined []struct {
model.Customer model.Customer
@ -753,7 +753,7 @@ LIMIT 1000;
} }
func TestSelectSelfJoin(t *testing.T) { func TestSelectSelfJoin(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT f1.film_id AS "f1.film_id", SELECT f1.film_id AS "f1.film_id",
f1.title AS "f1.title", f1.title AS "f1.title",
f1.description AS "f1.description", f1.description AS "f1.description",
@ -793,7 +793,7 @@ ORDER BY f1.film_id ASC;
SELECT(f1.AllColumns, f2.AllColumns). SELECT(f1.AllColumns, f2.AllColumns).
ORDER_BY(f1.FilmID.ASC()) ORDER_BY(f1.FilmID.ASC())
assertStatementSql(t, query, expectedSql) assertStatementSql(t, query, expectedSQL)
type F1 model.Film type F1 model.Film
type F2 model.Film type F2 model.Film
@ -813,7 +813,7 @@ ORDER BY f1.film_id ASC;
} }
func TestSelectAliasColumn(t *testing.T) { func TestSelectAliasColumn(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT f1.title AS "thesame_length_films.title1", SELECT f1.title AS "thesame_length_films.title1",
f2.title AS "thesame_length_films.title2", f2.title AS "thesame_length_films.title2",
f1.length AS "thesame_length_films.length" f1.length AS "thesame_length_films.length"
@ -835,7 +835,7 @@ LIMIT 1000;
ORDER_BY(f1.Length.ASC(), f1.Title.ASC(), f2.Title.ASC()). ORDER_BY(f1.Length.ASC(), f1.Title.ASC(), f2.Title.ASC()).
LIMIT(1000) LIMIT(1000)
assertStatementSql(t, query, expectedSql, int64(1000)) assertStatementSql(t, query, expectedSQL, int64(1000))
type thesameLengthFilms struct { type thesameLengthFilms struct {
Title1 string Title1 string
@ -886,11 +886,11 @@ FROM dvds.actor
WHERE(Film.Rating.EQ(enum.MpaaRating.R)). WHERE(Film.Rating.EQ(enum.MpaaRating.R)).
AsTable("rFilms") AsTable("rFilms")
rFilmId := Film.FilmID.From(rRatingFilms) rFilmID := Film.FilmID.From(rRatingFilms)
query := Actor. query := Actor.
INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.FilmID)). INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.FilmID)).
INNER_JOIN(rRatingFilms, FilmActor.FilmID.EQ(rFilmId)). INNER_JOIN(rRatingFilms, FilmActor.FilmID.EQ(rFilmID)).
SELECT( SELECT(
Actor.AllColumns, Actor.AllColumns,
FilmActor.AllColumns, FilmActor.AllColumns,
@ -928,7 +928,7 @@ FROM dvds.film;
} }
func TestSelectQueryScalar(t *testing.T) { func TestSelectQueryScalar(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT film.film_id AS "film.film_id", SELECT film.film_id AS "film.film_id",
film.title AS "film.title", film.title AS "film.title",
film.description AS "film.description", film.description AS "film.description",
@ -960,7 +960,7 @@ ORDER BY film.film_id ASC;
WHERE(Film.RentalRate.EQ(maxFilmRentalRate)). WHERE(Film.RentalRate.EQ(maxFilmRentalRate)).
ORDER_BY(Film.FilmID.ASC()) ORDER_BY(Film.FilmID.ASC())
assertStatementSql(t, query, expectedSql) assertStatementSql(t, query, expectedSQL)
maxRentalRateFilms := []model.Film{} maxRentalRateFilms := []model.Film{}
err := query.Query(db, &maxRentalRateFilms) err := query.Query(db, &maxRentalRateFilms)
@ -989,7 +989,7 @@ ORDER BY film.film_id ASC;
} }
func TestSelectGroupByHaving(t *testing.T) { func TestSelectGroupByHaving(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT payment.customer_id AS "customer_payment_sum.customer_id", SELECT payment.customer_id AS "customer_payment_sum.customer_id",
SUM(payment.amount) AS "customer_payment_sum.amount_sum", SUM(payment.amount) AS "customer_payment_sum.amount_sum",
AVG(payment.amount) AS "customer_payment_sum.amount_avg", AVG(payment.amount) AS "customer_payment_sum.amount_avg",
@ -1018,7 +1018,7 @@ ORDER BY SUM(payment.amount) ASC;
SUMf(Payment.Amount).GT(Float(100)), SUMf(Payment.Amount).GT(Float(100)),
) )
assertStatementSql(t, customersPaymentQuery, expectedSql, float64(100)) assertStatementSql(t, customersPaymentQuery, expectedSQL, float64(100))
type CustomerPaymentSum struct { type CustomerPaymentSum struct {
CustomerID int16 CustomerID int16
@ -1047,7 +1047,7 @@ ORDER BY SUM(payment.amount) ASC;
} }
func TestSelectGroupBy2(t *testing.T) { func TestSelectGroupBy2(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id", SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id", customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name", customer.first_name AS "customer.first_name",
@ -1077,18 +1077,18 @@ ORDER BY customer_payment_sum."amount_sum" ASC;
GROUP_BY(Payment.CustomerID). GROUP_BY(Payment.CustomerID).
AsTable("customer_payment_sum") AsTable("customer_payment_sum")
customerId := Payment.CustomerID.From(customersPayments) customerID := Payment.CustomerID.From(customersPayments)
amountSum := FloatColumn("amount_sum").From(customersPayments) amountSum := FloatColumn("amount_sum").From(customersPayments)
query := Customer. query := Customer.
INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerId)). INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerID)).
SELECT( SELECT(
Customer.AllColumns, Customer.AllColumns,
amountSum.AS("CustomerWithAmounts.AmountSum"), amountSum.AS("CustomerWithAmounts.AmountSum"),
). ).
ORDER_BY(amountSum.ASC()) ORDER_BY(amountSum.ASC())
assertStatementSql(t, query, expectedSql) assertStatementSql(t, query, expectedSQL)
type CustomerWithAmounts struct { type CustomerWithAmounts struct {
Customer *model.Customer Customer *model.Customer
@ -1123,7 +1123,7 @@ func TestSelectStaff(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assertJson(t, staffs, ` assertJSON(t, staffs, `
[ [
{ {
"StaffID": 1, "StaffID": 1,
@ -1157,7 +1157,7 @@ func TestSelectStaff(t *testing.T) {
func TestSelectTimeColumns(t *testing.T) { func TestSelectTimeColumns(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id", SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id", payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id", payment.staff_id AS "payment.staff_id",
@ -1173,7 +1173,7 @@ ORDER BY payment.payment_date ASC;
WHERE(Payment.PaymentDate.LT(Timestamp(2007, 02, 14, 22, 16, 01, 0))). WHERE(Payment.PaymentDate.LT(Timestamp(2007, 02, 14, 22, 16, 01, 0))).
ORDER_BY(Payment.PaymentDate.ASC()) ORDER_BY(Payment.PaymentDate.ASC())
assertStatementSql(t, query, expectedSql, "2007-02-14 22:16:01.000") assertStatementSql(t, query, expectedSQL, "2007-02-14 22:16:01.000")
payments := []model.Payment{} payments := []model.Payment{}
@ -1260,8 +1260,8 @@ func TestAllSetOperators(t *testing.T) {
UNION_ALL, UNION_ALL,
INTERSECT, INTERSECT,
INTERSECT_ALL, INTERSECT_ALL,
EXCEPT, //EXCEPT,
EXCEPT_ALL, //EXCEPT_ALL,
} }
expectedDestLen := []int{ expectedDestLen := []int{
@ -1304,63 +1304,15 @@ LIMIT 20;
assertStatementSql(t, query, expectedQuery, int64(1), "ONE", int64(2), "TWO", int64(3), "THREE", "OTHER", int64(20)) assertStatementSql(t, query, expectedQuery, int64(1), "ONE", int64(2), "TWO", int64(3), "THREE", "OTHER", int64(20))
dest := []struct { dest := []struct {
StaffIdNum string StaffIDNum string
}{} }{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 20) assert.Equal(t, len(dest), 20)
assert.Equal(t, dest[0].StaffIdNum, "TWO") assert.Equal(t, dest[0].StaffIDNum, "TWO")
assert.Equal(t, dest[1].StaffIdNum, "ONE") assert.Equal(t, dest[1].StaffIDNum, "ONE")
}
func TestLockTable(t *testing.T) {
expectedSql := `
LOCK TABLE dvds.address IN`
var testData = []TableLockMode{
LOCK_ACCESS_SHARE,
LOCK_ROW_SHARE,
LOCK_ROW_EXCLUSIVE,
LOCK_SHARE_UPDATE_EXCLUSIVE,
LOCK_SHARE,
LOCK_SHARE_ROW_EXCLUSIVE,
LOCK_EXCLUSIVE,
LOCK_ACCESS_EXCLUSIVE,
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode)
assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode).NOWAIT()
assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE NOWAIT;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
} }
func getRowLockTestData() map[SelectLock]string { func getRowLockTestData() map[SelectLock]string {
@ -1373,7 +1325,7 @@ func getRowLockTestData() map[SelectLock]string {
} }
func TestRowLock(t *testing.T) { func TestRowLock(t *testing.T) {
expectedSql := ` expectedSQL := `
SELECT * SELECT *
FROM dvds.address FROM dvds.address
LIMIT 3 LIMIT 3
@ -1385,7 +1337,7 @@ FOR`
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType) query.FOR(lockType)
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+";\n", int64(3)) assertStatementSql(t, query, expectedSQL+" "+lockTypeStr+";\n", int64(3))
tx, _ := db.Begin() tx, _ := db.Begin()
@ -1401,7 +1353,7 @@ FOR`
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.NOWAIT()) query.FOR(lockType.NOWAIT())
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" NOWAIT;\n", int64(3)) assertStatementSql(t, query, expectedSQL+" "+lockTypeStr+" NOWAIT;\n", int64(3))
tx, _ := db.Begin() tx, _ := db.Begin()
@ -1417,7 +1369,7 @@ FOR`
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.SKIP_LOCKED()) query.FOR(lockType.SKIP_LOCKED())
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3)) assertStatementSql(t, query, expectedSQL+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3))
tx, _ := db.Begin() tx, _ := db.Begin()
@ -1433,7 +1385,7 @@ FOR`
func TestQuickStart(t *testing.T) { func TestQuickStart(t *testing.T) {
var expectedSql = ` var expectedSQL = `
SELECT actor.actor_id AS "actor.actor_id", SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name", actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name", actor.last_name AS "actor.last_name",
@ -1488,7 +1440,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
Film.FilmID.ASC(), Film.FilmID.ASC(),
) )
assertStatementSql(t, stmt, expectedSql, "English", "Action", int64(180)) assertStatementSql(t, stmt, expectedSQL, "English", "Action", int64(180))
var dest []struct { var dest []struct {
model.Actor model.Actor
@ -1506,7 +1458,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
assert.NilError(t, err) assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest.json", dest) //jsonSave("./testdata/quick-start-dest.json", dest)
assertJsonFile(t, "./testdata/quick-start-dest.json", dest) assertJSONFile(t, "./testdata/quick-start-dest.json", dest)
var dest2 []struct { var dest2 []struct {
model.Category model.Category
@ -1519,7 +1471,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
assert.NilError(t, err) assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest2.json", dest2) //jsonSave("./testdata/quick-start-dest2.json", dest2)
assertJsonFile(t, "./testdata/quick-start-dest2.json", dest2) assertJSONFile(t, "./testdata/quick-start-dest2.json", dest2)
} }
func TestQuickStartWithSubQueries(t *testing.T) { func TestQuickStartWithSubQueries(t *testing.T) {
@ -1529,22 +1481,22 @@ func TestQuickStartWithSubQueries(t *testing.T) {
WHERE(Film.Length.GT(Int(180))). WHERE(Film.Length.GT(Int(180))).
AsTable("films") AsTable("films")
filmId := Film.FilmID.From(filmLogerThan180) filmID := Film.FilmID.From(filmLogerThan180)
filmLanguageId := Film.LanguageID.From(filmLogerThan180) filmLanguageID := Film.LanguageID.From(filmLogerThan180)
categoriesNotAction := Category. categoriesNotAction := Category.
SELECT(Category.AllColumns). SELECT(Category.AllColumns).
WHERE(Category.Name.NOT_EQ(String("Action"))). WHERE(Category.Name.NOT_EQ(String("Action"))).
AsTable("categories") AsTable("categories")
categoryId := Category.CategoryID.From(categoriesNotAction) categoryID := Category.CategoryID.From(categoriesNotAction)
stmt := Actor. stmt := Actor.
INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)). INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)).
INNER_JOIN(filmLogerThan180, filmId.EQ(FilmActor.FilmID)). INNER_JOIN(filmLogerThan180, filmID.EQ(FilmActor.FilmID)).
INNER_JOIN(Language, Language.LanguageID.EQ(filmLanguageId)). INNER_JOIN(Language, Language.LanguageID.EQ(filmLanguageID)).
INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(filmId)). INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(filmID)).
INNER_JOIN(categoriesNotAction, categoryId.EQ(FilmCategory.CategoryID)). INNER_JOIN(categoriesNotAction, categoryID.EQ(FilmCategory.CategoryID)).
SELECT( SELECT(
Actor.AllColumns, Actor.AllColumns,
filmLogerThan180.AllColumns(), filmLogerThan180.AllColumns(),
@ -1552,7 +1504,7 @@ func TestQuickStartWithSubQueries(t *testing.T) {
categoriesNotAction.AllColumns(), categoriesNotAction.AllColumns(),
).ORDER_BY( ).ORDER_BY(
Actor.ActorID.ASC(), Actor.ActorID.ASC(),
filmId.ASC(), filmID.ASC(),
) )
var dest []struct { var dest []struct {
@ -1571,7 +1523,7 @@ func TestQuickStartWithSubQueries(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest.json", dest) //jsonSave("./testdata/quick-start-dest.json", dest)
assertJsonFile(t, "./testdata/quick-start-dest.json", dest) assertJSONFile(t, "./testdata/quick-start-dest.json", dest)
var dest2 []struct { var dest2 []struct {
model.Category model.Category
@ -1584,5 +1536,5 @@ func TestQuickStartWithSubQueries(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest2.json", dest2) //jsonSave("./testdata/quick-start-dest2.json", dest2)
assertJsonFile(t, "./testdata/quick-start-dest2.json", dest2) assertJSONFile(t, "./testdata/quick-start-dest2.json", dest2)
} }

View file

@ -1,11 +1,13 @@
package tests package tests
import ( import (
"context"
. "github.com/go-jet/jet" . "github.com/go-jet/jet"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
"time"
) )
func TestUpdateValues(t *testing.T) { func TestUpdateValues(t *testing.T) {
@ -16,12 +18,12 @@ func TestUpdateValues(t *testing.T) {
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
var expectedSql = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET (name, url) = ('Bong', 'http://bong.com') SET (name, url) = ('Bong', 'http://bong.com')
WHERE link.name = 'Bing'; WHERE link.name = 'Bing';
` `
assertStatementSql(t, query, expectedSql, "Bong", "http://bong.com", "Bing") assertStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
assertExec(t, query, 1) assertExec(t, query, 1)
@ -54,7 +56,7 @@ func TestUpdateWithSubQueries(t *testing.T) {
). ).
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
expectedSql := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET (name, url) = (( SET (name, url) = ((
SELECT 'Bong' SELECT 'Bong'
@ -66,7 +68,7 @@ SET (name, url) = ((
WHERE link.name = 'Bing'; WHERE link.name = 'Bing';
` `
assertStatementSql(t, query, expectedSql, "Bong", "Bing", "Bing") assertStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing")
assertExec(t, query, 1) assertExec(t, query, 1)
} }
@ -74,7 +76,7 @@ WHERE link.name = 'Bing';
func TestUpdateAndReturning(t *testing.T) { func TestUpdateAndReturning(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
expectedSql := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com') SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.name = 'Ask' WHERE link.name = 'Ask'
@ -90,7 +92,7 @@ RETURNING link.id AS "link.id",
WHERE(Link.Name.EQ(String("Ask"))). WHERE(Link.Name.EQ(String("Ask"))).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
assertStatementSql(t, stmt, expectedSql, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") assertStatementSql(t, stmt, expectedSQL, "DuckDuckGo", "http://www.duckduckgo.com", "Ask")
links := []model.Link{} links := []model.Link{}
@ -112,7 +114,7 @@ func TestUpdateWithSelect(t *testing.T) {
). ).
WHERE(Link.ID.EQ(Int(0))) WHERE(Link.ID.EQ(Int(0)))
expectedSql := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description) = ( SET (id, url, name, description) = (
SELECT link.id AS "link.id", SELECT link.id AS "link.id",
@ -124,7 +126,7 @@ SET (id, url, name, description) = (
) )
WHERE link.id = 0; WHERE link.id = 0;
` `
assertStatementSql(t, stmt, expectedSql, int64(0), int64(0)) assertStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
assertExec(t, stmt, 1) assertExec(t, stmt, 1)
} }
@ -139,7 +141,7 @@ func TestUpdateWithInvalidSelect(t *testing.T) {
). ).
WHERE(Link.ID.EQ(Int(0))) WHERE(Link.ID.EQ(Int(0)))
var expectedSql = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description) = ( SET (id, url, name, description) = (
SELECT link.id AS "link.id", SELECT link.id AS "link.id",
@ -149,7 +151,7 @@ SET (id, url, name, description) = (
) )
WHERE link.id = 0; WHERE link.id = 0;
` `
assertStatementSql(t, stmt, expectedSql, int64(0), int64(0)) assertStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
assertExecErr(t, stmt, "pq: number of columns does not match number of values") assertExecErr(t, stmt, "pq: number of columns does not match number of values")
} }
@ -168,12 +170,12 @@ func TestUpdateWithModelData(t *testing.T) {
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.ID)))) WHERE(Link.ID.EQ(Int(int64(link.ID))))
expectedSql := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL)
WHERE link.id = 201; WHERE link.id = 201;
` `
assertStatementSql(t, stmt, expectedSql, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) assertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
assertExec(t, stmt, 1) assertExec(t, stmt, 1)
} }
@ -195,12 +197,12 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.ID)))) WHERE(Link.ID.EQ(Int(int64(link.ID))))
var expectedSql = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.id = 201; WHERE link.id = 201;
` `
assertStatementSql(t, stmt, expectedSql, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) assertStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
assertExec(t, stmt, 1) assertExec(t, stmt, 1)
} }
@ -231,16 +233,53 @@ func TestUpdateWithInvalidModelData(t *testing.T) {
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.Ident)))) WHERE(Link.ID.EQ(Int(int64(link.Ident))))
var expectedSql = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL) SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL)
WHERE link.id = 201; WHERE link.id = 201;
` `
assertStatementSql(t, stmt, expectedSql, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) assertStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201))
assertExecErr(t, stmt, "pq: number of columns does not match number of values") assertExecErr(t, stmt, "pq: number of columns does not match number of values")
} }
func TestUpdateQueryContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := updateStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestUpdateExecContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := updateStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}
func setupLinkTableForUpdateTest(t *testing.T) { func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)

View file

@ -1,5 +1,6 @@
package jet package jet
// TimeExpression interface
type TimeExpression interface { type TimeExpression interface {
Expression Expression
@ -58,19 +59,15 @@ type prefixTimeExpression struct {
prefixOpExpression prefixOpExpression
} }
func newPrefixTimeExpression(operator string, expression Expression) TimeExpression { //func newPrefixTimeExpression(operator string, expression Expression) TimeExpression {
timeExpr := prefixTimeExpression{} // timeExpr := prefixTimeExpression{}
timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
//
timeExpr.expressionInterfaceImpl.parent = &timeExpr // timeExpr.expressionInterfaceImpl.parent = &timeExpr
timeExpr.timeInterfaceImpl.parent = &timeExpr // timeExpr.timeInterfaceImpl.parent = &timeExpr
//
return &timeExpr // return &timeExpr
} //}
func INTERVAL(interval string) Expression {
return newPrefixTimeExpression("INTERVAL", literal(interval))
}
//---------------------------------------------------// //---------------------------------------------------//
@ -85,6 +82,9 @@ func newTimeExpressionWrap(expression Expression) TimeExpression {
return &timeExpressionWrap return &timeExpressionWrap
} }
// TimeExp is time expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time expression.
// Does not add sql cast to generated sql builder output.
func TimeExp(expression Expression) TimeExpression { func TimeExp(expression Expression) TimeExpression {
return newTimeExpressionWrap(expression) return newTimeExpressionWrap(expression)
} }

View file

@ -4,34 +4,46 @@ import (
"testing" "testing"
) )
var timeVar = Time(10, 20, 0, 0)
func TestTimeExpressionEQ(t *testing.T) { func TestTimeExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)") assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)")
assertClauseSerialize(t, table1ColTime.EQ(Time(10, 20, 0, 0)), "(table1.col_time = $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1::time without time zone)", "10:20:00.000")
} }
func TestTimeExpressionNOT_EQ(t *testing.T) { func TestTimeExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)") assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)")
assertClauseSerialize(t, table1ColTime.NOT_EQ(Time(10, 20, 0, 0)), "(table1.col_time != $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1::time without time zone)", "10:20:00.000")
} }
func TestTimeExpressionLT(t *testing.T) { func TestTimeExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)") assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT(Time(10, 20, 0, 0)), "(table1.col_time < $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1::time without time zone)", "10:20:00.000")
} }
func TestTimeExpressionLT_EQ(t *testing.T) { func TestTimeExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)") assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT_EQ(Time(10, 20, 0, 0)), "(table1.col_time <= $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1::time without time zone)", "10:20:00.000")
} }
func TestTimeExpressionGT(t *testing.T) { func TestTimeExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)") assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT(Time(10, 20, 0, 0)), "(table1.col_time > $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1::time without time zone)", "10:20:00.000")
} }
func TestTimeExpressionGT_EQ(t *testing.T) { func TestTimeExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)") assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT_EQ(Time(10, 20, 0, 0)), "(table1.col_time >= $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1::time without time zone)", "10:20:00.000")
} }
func TestTimeExp(t *testing.T) { func TestTimeExp(t *testing.T) {

View file

@ -1,5 +1,6 @@
package jet package jet
// TimestampExpression interface
type TimestampExpression interface { type TimestampExpression interface {
Expression Expression
@ -63,6 +64,9 @@ func newTimestampExpressionWrap(expression Expression) TimestampExpression {
return &timestampExpressionWrap return &timestampExpressionWrap
} }
// TimestampExp is timestamp expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp expression.
// Does not add sql cast to generated sql builder output.
func TimestampExp(expression Expression) TimestampExpression { func TimestampExp(expression Expression) TimestampExpression {
return newTimestampExpressionWrap(expression) return newTimestampExpressionWrap(expression)
} }

View file

@ -0,0 +1,52 @@
package jet
import "testing"
var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 0)
func TestTimestampExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp),
"(table1.col_timestamp = $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}

View file

@ -1,5 +1,6 @@
package jet package jet
// TimestampzExpression interface
type TimestampzExpression interface { type TimestampzExpression interface {
Expression Expression
@ -63,6 +64,9 @@ func newTimestampzExpressionWrap(expression Expression) TimestampzExpression {
return &timestampzExpressionWrap return &timestampzExpressionWrap
} }
// TimestampzExp is timestamp with time zone expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp with time zone expression.
// Does not add sql cast to generated sql builder output.
func TimestampzExp(expression Expression) TimestampzExpression { func TimestampzExp(expression Expression) TimestampzExpression {
return newTimestampzExpressionWrap(expression) return newTimestampzExpressionWrap(expression)
} }

View file

@ -0,0 +1,52 @@
package jet
import "testing"
var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2)
func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}

View file

@ -1,16 +1,25 @@
package jet package jet
// TimezExpression interface 'time with time zone'
type TimezExpression interface { type TimezExpression interface {
Expression Expression
//EQ
EQ(rhs TimezExpression) BoolExpression EQ(rhs TimezExpression) BoolExpression
//NOT_EQ
NOT_EQ(rhs TimezExpression) BoolExpression NOT_EQ(rhs TimezExpression) BoolExpression
//IS_DISTINCT_FROM
IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression
//IS_NOT_DISTINCT_FROM
IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression
//LT
LT(rhs TimezExpression) BoolExpression LT(rhs TimezExpression) BoolExpression
//LT_EQ
LT_EQ(rhs TimezExpression) BoolExpression LT_EQ(rhs TimezExpression) BoolExpression
//GT
GT(rhs TimezExpression) BoolExpression GT(rhs TimezExpression) BoolExpression
//GT_EQ
GT_EQ(rhs TimezExpression) BoolExpression GT_EQ(rhs TimezExpression) BoolExpression
} }
@ -58,15 +67,15 @@ type prefixTimezExpression struct {
prefixOpExpression prefixOpExpression
} }
func newPrefixTimezExpression(operator string, expression Expression) TimezExpression { //func newPrefixTimezExpression(operator string, expression Expression) TimezExpression {
timeExpr := prefixTimezExpression{} // timeExpr := prefixTimezExpression{}
timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
//
timeExpr.expressionInterfaceImpl.parent = &timeExpr // timeExpr.expressionInterfaceImpl.parent = &timeExpr
timeExpr.timezInterfaceImpl.parent = &timeExpr // timeExpr.timezInterfaceImpl.parent = &timeExpr
//
return &timeExpr // return &timeExpr
} //}
//---------------------------------------------------// //---------------------------------------------------//
@ -81,6 +90,9 @@ func newTimezExpressionWrap(expression Expression) TimezExpression {
return &timezExpressionWrap return &timezExpressionWrap
} }
// TimezExp is time with time zone expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time with time zone expression.
// Does not add sql cast to generated sql builder output.
func TimezExp(expression Expression) TimezExpression { func TimezExp(expression Expression) TimezExpression {
return newTimezExpressionWrap(expression) return newTimezExpressionWrap(expression)
} }

51
timez_expression_test.go Normal file
View file

@ -0,0 +1,51 @@
package jet
import "testing"
var timezVar = Timez(10, 20, 0, 0, 4)
func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04"))
}

View file

@ -5,8 +5,10 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
) )
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface { type UpdateStatement interface {
Statement Statement
@ -61,11 +63,11 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
out.newLine() out.newLine()
out.writeString("UPDATE") out.writeString("UPDATE")
if isNil(u.table) { if utils.IsNil(u.table) {
return "", nil, errors.New("jet: table to update is nil") return "", nil, errors.New("jet: table to update is nil")
} }
if err = u.table.serialize(update_statement, out); err != nil { if err = u.table.serialize(updateStatement, out); err != nil {
return return
} }
@ -100,7 +102,7 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
out.writeString("(") out.writeString("(")
} }
err = serializeClauseList(update_statement, u.row, out) err = serializeClauseList(updateStatement, u.row, out)
if err != nil { if err != nil {
return return
@ -114,11 +116,11 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return "", nil, errors.New("jet: WHERE clause not set") return "", nil, errors.New("jet: WHERE clause not set")
} }
if err = out.writeWhere(update_statement, u.where); err != nil { if err = out.writeWhere(updateStatement, u.where); err != nil {
return return
} }
if err = out.writeReturning(update_statement, u.returning); err != nil { if err = out.writeReturning(updateStatement, u.returning); err != nil {
return return
} }
@ -134,14 +136,14 @@ func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) er
return query(u, db, destination) return query(u, db, destination)
} }
func (u *updateStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { func (u *updateStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(u, db, context, destination) return queryContext(context, u, db, destination)
} }
func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(u, db) return exec(u, db)
} }
func (u *updateStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { func (u *updateStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(u, db, context) return execContext(context, u, db)
} }

View file

@ -5,7 +5,7 @@ import (
) )
func TestUpdateWithOneValue(t *testing.T) { func TestUpdateWithOneValue(t *testing.T) {
expectedSql := ` expectedSQL := `
UPDATE db.table1 UPDATE db.table1
SET col_int = $1 SET col_int = $1
WHERE table1.col_int >= $2; WHERE table1.col_int >= $2;
@ -14,11 +14,11 @@ WHERE table1.col_int >= $2;
SET(1). SET(1).
WHERE(table1ColInt.GT_EQ(Int(33))) WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSql, 1, int64(33)) assertStatement(t, stmt, expectedSQL, 1, int64(33))
} }
func TestUpdateWithValues(t *testing.T) { func TestUpdateWithValues(t *testing.T) {
expectedSql := ` expectedSQL := `
UPDATE db.table1 UPDATE db.table1
SET (col_int, col_float) = ($1, $2) SET (col_int, col_float) = ($1, $2)
WHERE table1.col_int >= $3; WHERE table1.col_int >= $3;
@ -27,11 +27,11 @@ WHERE table1.col_int >= $3;
SET(1, 22.2). SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33))) WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSql, 1, 22.2, int64(33)) assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33))
} }
func TestUpdateOneColumnWithSelect(t *testing.T) { func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSql := ` expectedSQL := `
UPDATE db.table1 UPDATE db.table1
SET col_float = ( SET col_float = (
SELECT table1.col_float AS "table1.col_float" SELECT table1.col_float AS "table1.col_float"
@ -48,11 +48,11 @@ RETURNING table1.col1 AS "table1.col1";
WHERE(table1Col1.EQ(Int(2))). WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1) RETURNING(table1Col1)
assertStatement(t, stmt, expectedSql, int64(2)) assertStatement(t, stmt, expectedSQL, int64(2))
} }
func TestUpdateColumnsWithSelect(t *testing.T) { func TestUpdateColumnsWithSelect(t *testing.T) {
expectedSql := ` expectedSQL := `
UPDATE db.table1 UPDATE db.table1
SET (col1, col_float) = ( SET (col1, col_float) = (
SELECT table1.col_float AS "table1.col_float", SELECT table1.col_float AS "table1.col_float",
@ -67,7 +67,7 @@ RETURNING table1.col1 AS "table1.col1";
WHERE(table1Col1.EQ(Int(2))). WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1) RETURNING(table1Col1)
assertStatement(t, stmt, expectedSql, int64(2)) assertStatement(t, stmt, expectedSQL, int64(2))
} }
func TestInvalidInputs(t *testing.T) { func TestInvalidInputs(t *testing.T) {

View file

@ -7,7 +7,7 @@ import (
"strings" "strings"
) )
func serializeOrderByClauseList(statement statementType, orderByClauses []OrderByClause, out *sqlBuilder) error { func serializeOrderByClauseList(statement statementType, orderByClauses []orderByClause, out *sqlBuilder) error {
for i, value := range orderByClauses { for i, value := range orderByClauses {
if i > 0 { if i > 0 {
@ -32,7 +32,7 @@ func serializeGroupByClauseList(statement statementType, clauses []groupByClause
} }
if c == nil { if c == nil {
return errors.New("jet: nil clause.") return errors.New("jet: nil clause")
} }
if err = c.serializeForGroupBy(statement, out); err != nil { if err = c.serializeForGroupBy(statement, out); err != nil {
@ -51,7 +51,7 @@ func serializeClauseList(statement statementType, clauses []clause, out *sqlBuil
} }
if c == nil { if c == nil {
return errors.New("jet: nil clause.") return errors.New("jet: nil clause")
} }
if err = c.serialize(statement, out); err != nil { if err = c.serialize(statement, out); err != nil {
@ -124,16 +124,12 @@ func columnListToProjectionList(columns []Column) []projection {
return ret return ret
} }
func isNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
}
func valueToClause(value interface{}) clause { func valueToClause(value interface{}) clause {
if clause, ok := value.(clause); ok { if clause, ok := value.(clause); ok {
return clause return clause
} else {
return literal(value)
} }
return literal(value)
} }
func unwindRowFromModel(columns []column, data interface{}) []clause { func unwindRowFromModel(columns []column, data interface{}) []clause {

View file

@ -10,7 +10,11 @@ var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float") var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3") var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time") var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool") var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1 = NewTable( var table1 = NewTable(
"db", "db",
@ -20,7 +24,12 @@ var table1 = NewTable(
table1ColFloat, table1ColFloat,
table1Col3, table1Col3,
table1ColTime, table1ColTime,
table1ColBool) table1ColTimez,
table1ColBool,
table1ColDate,
table1ColTimestamp,
table1ColTimestampz,
)
var table2Col3 = IntegerColumn("col3") var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4") var table2Col4 = IntegerColumn("col4")
@ -29,6 +38,10 @@ var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str") var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool") var table2ColBool = BoolColumn("col_bool")
var table2ColTime = TimeColumn("col_time") var table2ColTime = TimeColumn("col_time")
var table2ColTimez = TimezColumn("col_timez")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable( var table2 = NewTable(
"db", "db",
@ -39,7 +52,12 @@ var table2 = NewTable(
table2ColFloat, table2ColFloat,
table2ColStr, table2ColStr,
table2ColBool, table2ColBool,
table2ColTime) table2ColTime,
table2ColTimez,
table2ColDate,
table2ColTimestamp,
table2ColTimestampz,
)
var table3Col1 = IntegerColumn("col1") var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int") var table3ColInt = IntegerColumn("col_int")
@ -53,7 +71,7 @@ var table3 = NewTable(
func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) { func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) {
out := sqlBuilder{} out := sqlBuilder{}
err := clause.serialize(select_statement, &out) err := clause.serialize(selectStatement, &out)
assert.NilError(t, err) assert.NilError(t, err)
@ -63,7 +81,7 @@ func assertClauseSerialize(t *testing.T, clause clause, query string, args ...in
func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
out := sqlBuilder{} out := sqlBuilder{}
err := clause.serialize(select_statement, &out) err := clause.serialize(selectStatement, &out)
//fmt.Println(out.buff.String()) //fmt.Println(out.buff.String())
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
@ -72,7 +90,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) {
out := sqlBuilder{} out := sqlBuilder{}
err := projection.serializeForProjection(select_statement, &out) err := projection.serializeForProjection(selectStatement, &out)
assert.NilError(t, err) assert.NilError(t, err)