Merge pull request #7 from go-jet/develop

Merge develop to master
This commit is contained in:
go-jet 2019-07-23 13:42:47 +02:00 committed by GitHub
commit 05311b8129
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
84 changed files with 9819 additions and 1140 deletions

View file

@ -34,10 +34,7 @@ jobs:
go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report
- run: mkdir -p $TEST_RESULTS/unit-tests
- run: mkdir -p $TEST_RESULTS/integration-tests
- run: go test -v 2>&1 | go-junit-report > $TEST_RESULTS/unit-tests/results.xml
go install github.com/go-jet/jet/cmd/jet
- run:
name: Waiting for Postgres to be ready
@ -51,13 +48,19 @@ jobs:
echo Failed waiting for Postgres && exit 1
- run:
name: Run integration tests
name: Init Postgres database
command: |
cd tests
go run ./init/init.go
go test -v 2>&1 | go-junit-report > $TEST_RESULTS/integration-tests/results.xml
cd ..
- run: mkdir -p $TEST_RESULTS
- run: go test -v . ./tests -coverpkg=github.com/go-jet/jet,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run:
name: Upload code coverage
command: bash <(curl -s https://codecov.io/bash)
- store_artifacts: # Upload test summary for display in Artifacts: https://circleci.com/docs/2.0/artifacts/
path: /tmp/test-results
destination: raw-test-output

View file

@ -1,10 +1,13 @@
# Jet
[![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet)
[![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet)
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily
convert database query result to desired arbitrary structure.
_Support for additional databases will be added in future jet releases._
_*Support for additional databases will be added in future jet releases._
## Contents
@ -22,7 +25,7 @@ _Support for additional databases will be added in future jet releases._
- [License](#license)
## Features
1) Type-safe SQL Builder
1) Auto-generated type-safe SQL Builder
- Types - boolean, integers(smallint, integer, bigint), floats(real, numeric, decimal, double precision),
strings(text, character, character varying), date, time(z), timestamp(z) and enums.
- Statements:
@ -31,10 +34,9 @@ _Support for additional databases will be added in future jet releases._
* UPDATE (SET, WHERE, RETURNING),
* DELETE (WHERE, RETURNING),
* LOCK (IN, NOWAIT)
2) Auto-generated Data Model types - Go struct mapped to database type (table or enum), used to store
result of database queries.
3) Query execution with mapping to arbitrary destination structure - destination structure can be
created by combining auto-generated data model types.
2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store
result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure.
## Getting Started
@ -63,15 +65,16 @@ Make sure GOPATH bin folder is added to the PATH environment variable.
For this quick start example we will use sample _dvd rental_ database. Full database dump can be found in [./tests/init/data/dvds.sql](./tests/init/data/dvds.sql).
Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png).
#### Generate sql builder and model files
#### Generate SQL Builder and Model files
To generate jet SQL Builder and Data Model files from postgres database we need to call `jet` generator with postgres
connection parameters and root destination folder path for generated files.\
Assuming we are running local postgres database, with user `jet`, database `jetdb` and schema `dvds` we will use this command:
Assuming we are running local postgres database, with user `jetuser`, user password `jetpass`, database `jetdb` and
schema `dvds` we will use this command:
```sh
jet -host=localhost -port=5432 -user=jet -password=jet -dbname=jetdb -schema=dvds -path=./gen
jet -host=localhost -port=5432 -user=jetuser -password=jetpass -dbname=jetdb -schema=dvds -path=./gen
```
```sh
Connecting to postgres database: host=localhost port=5432 user=jet password=jet dbname=jetdb sslmode=disable
Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable
Retrieving schema information...
FOUND 15 table(s), 1 enum(s)
Destination directory: ./gen/jetdb/dvds
@ -82,11 +85,12 @@ Generating enum sql builder files...
Generating enum model files...
Done
```
_*User has to have a permission to read information schema tables_
As command output suggest, Jet will:
- connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema
- delete everything in schema destination folder - `./gen/jetdb/dvds`,
- and finally generate sql builder and model files for each schema tables and enums.
- and finally generate SQL Builder and Model files for each schema table and enum.
Generated files folder structure will look like this:
@ -101,13 +105,13 @@ Generated files folder structure will look like this:
| |-- address.go
| |-- category.go
| ...
| |-- model # Plain Old Data for every table and enum
| |-- model # model files for each table and enum
| | |-- actor.go
| | |-- address.go
| | |-- mpaa_rating.go
| | ...
```
Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types are combined to store
Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store
results of the SQL queries.
#### Lets write some SQL queries in Go
@ -147,12 +151,11 @@ stmt := SELECT(
Film.FilmID.ASC(),
)
```
With package(dot) import above statements looks almost the same as native SQL.
Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
With package(dot) import above statements looks almost the same as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with
string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns
and can be compared only with integer columns and expressions.
__How to get parametrized SQL query?__
__How to get parametrized SQL query from statement?__
```go
query, args, err := stmt.Sql()
```
@ -160,7 +163,7 @@ query - parametrized query\
args - parameters for the query
<details>
<summary>Click to see `query` and `arg`</summary>
<summary>Click to see `query` and `args`</summary>
```sql
SELECT actor.actor_id AS "actor.actor_id",
@ -202,11 +205,12 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
</details>
__How to get debug SQL that can be copy pasted to sql editor and executed?__
```go
__How to get debug SQL from statement?__
```go
debugSql, err := stmt.DebugSql()
```
debugSql - parametrized query where every parameter is replaced with appropriate string argument representation
debugSql - query string that can be copy pasted to sql editor and executed. It's not intended to be used in production.
<details>
<summary>Click to see debug sql</summary>
@ -252,20 +256,24 @@ Well formed SQL is just a first half the job. Lets see how can we make some sens
above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest.
First we have to create desired structure to store query result set.
This is done be combining autogenerated model types or it can be done manually(see wiki for more information).
This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information).
Let's say this is our desired structure:
Let's say this is our desired structure, created by combining auto-generated model types:
```go
var dest []struct {
model.Actor
Films []struct {
model.Film
Language model.Language
Categories []model.Category
Language model.Language
Categories []model.Category
}
}
```
_There is no limitation for how big or nested destination structure can be._
Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language
`Langauge` field is just a single model struct.
_*There is no limitation of how big or nested destination structure can be._
Now lets execute a above statement on open database connection db and store result into `dest`.
@ -494,12 +502,12 @@ ORM sometimes can access the database once for every object needed. Now lets say
different objects required from the database. This handler will last 3 seconds !!!.
With Jet, handler time lost on latency between server and database is constant. Because we can write complex query and
return result in one database call. Handler execution will be proportional to the number of rows returned from database.
return result in one database call. Handler execution will be only proportional to the number of rows returned from database.
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/chinook_db_test.go#L40).
The whole test database is joined and query result is stored in a structured variable in less than 1s.
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found
@ -534,4 +542,4 @@ To run the tests, additional dependencies are required:
## License
Copyright 2019 Goran Bjelanovic
Licensed under the Apache License, Version 2.0.
Licensed under the Apache License, Version 2.0.

View file

@ -12,7 +12,7 @@ func newAlias(expression Expression, aliasName string) projection {
}
}
func (a *alias) from(subQuery ExpressionTable) projection {
func (a *alias) from(subQuery SelectTable) projection {
column := newColumn(a.alias, "", nil)
column.parent = &column
column.subQuery = subQuery

View file

@ -1,5 +1,6 @@
package jet
//BoolExpression interface
type BoolExpression interface {
Expression
@ -150,6 +151,9 @@ func newBoolExpressionWrap(expression Expression) BoolExpression {
return &boolExpressionWrap
}
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
func BoolExp(expression Expression) BoolExpression {
return newBoolExpressionWrap(expression)
}

View file

@ -15,6 +15,16 @@ func TestBoolExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColBool.NOT_EQ(Bool(true)), "(table1.col_bool != $1)", true)
}
func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS DISTINCT FROM table2.col_bool)")
assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS DISTINCT FROM $1)", false)
}
func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS NOT DISTINCT FROM table2.col_bool)")
assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS NOT DISTINCT FROM $1)", false)
}
func TestBoolExpressionIS_TRUE(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_TRUE(), "table1.col_bool IS TRUE")
assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(),

View file

@ -36,6 +36,8 @@ type castImpl struct {
castType string
}
// CAST wraps expression for casting.
// For instance: CAST(table.column).AS_BOOL()
func CAST(expression Expression) cast {
return &castImpl{
Expression: expression,

View file

@ -40,12 +40,12 @@ type sqlBuilder struct {
type statementType string
const (
select_statement statementType = "SELECT"
insert_statement statementType = "INSERT"
update_statement statementType = "UPDATE"
delete_statement statementType = "DELETE"
set_statement statementType = "SET"
lock_statement statementType = "LOCK"
selectStatement statementType = "SELECT"
insertStatement statementType = "INSERT"
updateStatement statementType = "UPDATE"
deleteStatement statementType = "DELETE"
setStatement statementType = "SET"
lockStatement statementType = "LOCK"
)
const defaultIdent = 5
@ -102,7 +102,7 @@ func (q *sqlBuilder) writeGroupBy(statement statementType, groupBy []groupByClau
return err
}
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []OrderByClause) error {
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.newLine()
q.writeString("ORDER BY")
@ -189,23 +189,18 @@ func (q *sqlBuilder) finalize() (string, []interface{}) {
}
func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
q.writeString(ArgToString(arg))
q.writeString(argToString(arg))
}
func (q *sqlBuilder) insertPreparedArgument(arg interface{}) {
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.writeString(argPlaceholder)
}
func (q *sqlBuilder) reset() {
q.buff.Reset()
q.args = []interface{}{}
}
func ArgToString(value interface{}) string {
if isNil(value) {
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
@ -213,9 +208,8 @@ func ArgToString(value interface{}) string {
case bool:
if bindVal {
return "TRUE"
} else {
return "FALSE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
@ -252,7 +246,7 @@ func ArgToString(value interface{}) string {
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unknown type]"
return "[Unsupported type]"
}
}

33
clause_test.go Normal file
View file

@ -0,0 +1,33 @@
package jet
import (
"github.com/google/uuid"
"gotest.tools/assert"
"testing"
"time"
)
func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(true), "TRUE")
assert.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int8(-8)), "-8")
assert.Equal(t, argToString(int16(-16)), "-16")
assert.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(int32(-32)), "-32")
assert.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint8(8)), "8")
assert.Equal(t, argToString(uint16(16)), "16")
assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NilError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
assert.Equal(t, argToString(map[string]bool{}), "[Unsupported type]")
}

View file

@ -4,8 +4,8 @@ import (
"flag"
"fmt"
"github.com/go-jet/jet/generator/postgres"
_ "github.com/lib/pq"
"os"
"strconv"
)
var (
@ -70,7 +70,7 @@ Usage of jet:
genData := postgres.DBConnection{
Host: host,
Port: strconv.Itoa(port),
Port: port,
User: user,
Password: password,
SslMode: sslmode,

View file

@ -7,10 +7,11 @@ type column interface {
TableName() string
setTableName(table string)
setSubQuery(subQuery ExpressionTable)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// Column is common column interface for all types of columns.
type Column interface {
Expression
column
@ -23,7 +24,7 @@ type columnImpl struct {
name string
tableName string
subQuery ExpressionTable
subQuery SelectTable
}
func newColumn(name string, tableName string, parent Column) columnImpl {
@ -49,7 +50,7 @@ func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery ExpressionTable) {
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
@ -62,7 +63,7 @@ func (c *columnImpl) defaultAlias() string {
}
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if statement == set_statement {
if statement == setStatement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.writeString(`"` + c.defaultAlias() + `"`) //always quote
@ -104,13 +105,13 @@ func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options
//------------------------------------------------------//
// Redefined type to support list of columns as projection
// ColumnList is redefined type to support list of columns as single projection
type ColumnList []Column
// projection interface implementation
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) from(subQuery ExpressionTable) projection {
func (cl ColumnList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
@ -134,8 +135,11 @@ func (cl ColumnList) serializeForProjection(statement statementType, out *sqlBui
// dummy column interface implementation
func (cl ColumnList) Name() string { return "" }
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery ExpressionTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }

View file

@ -1,10 +1,11 @@
package jet
// ColumnBool is interface for SQL boolean columns.
type ColumnBool interface {
BoolExpression
column
From(subQuery ExpressionTable) ColumnBool
From(subQuery SelectTable) ColumnBool
}
type boolColumnImpl struct {
@ -13,7 +14,7 @@ type boolColumnImpl struct {
columnImpl
}
func (i *boolColumnImpl) from(subQuery ExpressionTable) projection {
func (i *boolColumnImpl) from(subQuery SelectTable) projection {
newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery)
@ -21,12 +22,13 @@ func (i *boolColumnImpl) from(subQuery ExpressionTable) projection {
return newBoolColumn
}
func (i *boolColumnImpl) From(subQuery ExpressionTable) ColumnBool {
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := i.from(subQuery).(ColumnBool)
return newBoolColumn
}
// BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{}
boolColumn.columnImpl = newColumn(name, "", boolColumn)
@ -37,11 +39,12 @@ func BoolColumn(name string) ColumnBool {
//------------------------------------------------------//
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface {
FloatExpression
column
From(subQuery ExpressionTable) ColumnFloat
From(subQuery SelectTable) ColumnFloat
}
type floatColumnImpl struct {
@ -49,7 +52,7 @@ type floatColumnImpl struct {
columnImpl
}
func (i *floatColumnImpl) from(subQuery ExpressionTable) projection {
func (i *floatColumnImpl) from(subQuery SelectTable) projection {
newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery)
@ -57,12 +60,13 @@ func (i *floatColumnImpl) from(subQuery ExpressionTable) projection {
return newFloatColumn
}
func (i *floatColumnImpl) From(subQuery ExpressionTable) ColumnFloat {
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
newFloatColumn := i.from(subQuery).(ColumnFloat)
return newFloatColumn
}
// FloatColumn creates named float column.
func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{}
floatColumn.floatInterfaceImpl.parent = floatColumn
@ -73,11 +77,12 @@ func FloatColumn(name string) ColumnFloat {
//------------------------------------------------------//
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface {
IntegerExpression
column
From(subQuery ExpressionTable) ColumnInteger
From(subQuery SelectTable) ColumnInteger
}
type integerColumnImpl struct {
@ -86,7 +91,7 @@ type integerColumnImpl struct {
columnImpl
}
func (i *integerColumnImpl) from(subQuery ExpressionTable) projection {
func (i *integerColumnImpl) from(subQuery SelectTable) projection {
newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery)
@ -94,10 +99,11 @@ func (i *integerColumnImpl) from(subQuery ExpressionTable) projection {
return newIntColumn
}
func (i *integerColumnImpl) From(subQuery ExpressionTable) ColumnInteger {
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return i.from(subQuery).(ColumnInteger)
}
// IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{}
integerColumn.integerInterfaceImpl.parent = integerColumn
@ -108,11 +114,13 @@ func IntegerColumn(name string) ColumnInteger {
//------------------------------------------------------//
// ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types.
type ColumnString interface {
StringExpression
column
From(subQuery ExpressionTable) ColumnString
From(subQuery SelectTable) ColumnString
}
type stringColumnImpl struct {
@ -121,7 +129,7 @@ type stringColumnImpl struct {
columnImpl
}
func (i *stringColumnImpl) from(subQuery ExpressionTable) projection {
func (i *stringColumnImpl) from(subQuery SelectTable) projection {
newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery)
@ -129,10 +137,11 @@ func (i *stringColumnImpl) from(subQuery ExpressionTable) projection {
return newStrColumn
}
func (i *stringColumnImpl) From(subQuery ExpressionTable) ColumnString {
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return i.from(subQuery).(ColumnString)
}
// StringColumn creates named string column.
func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{}
stringColumn.stringInterfaceImpl.parent = stringColumn
@ -143,11 +152,12 @@ func StringColumn(name string) ColumnString {
//------------------------------------------------------//
// ColumnTime is interface for SQL time column.
type ColumnTime interface {
TimeExpression
column
From(subQuery ExpressionTable) ColumnTime
From(subQuery SelectTable) ColumnTime
}
type timeColumnImpl struct {
@ -155,7 +165,7 @@ type timeColumnImpl struct {
columnImpl
}
func (i *timeColumnImpl) from(subQuery ExpressionTable) projection {
func (i *timeColumnImpl) from(subQuery SelectTable) projection {
newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery)
@ -163,10 +173,11 @@ func (i *timeColumnImpl) from(subQuery ExpressionTable) projection {
return newTimeColumn
}
func (i *timeColumnImpl) From(subQuery ExpressionTable) ColumnTime {
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return i.from(subQuery).(ColumnTime)
}
// TimeColumn creates named time column
func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{}
timeColumn.timeInterfaceImpl.parent = timeColumn
@ -176,11 +187,12 @@ func TimeColumn(name string) ColumnTime {
//------------------------------------------------------//
// ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface {
TimezExpression
column
From(subQuery ExpressionTable) ColumnTimez
From(subQuery SelectTable) ColumnTimez
}
type timezColumnImpl struct {
@ -189,7 +201,7 @@ type timezColumnImpl struct {
columnImpl
}
func (i *timezColumnImpl) from(subQuery ExpressionTable) projection {
func (i *timezColumnImpl) from(subQuery SelectTable) projection {
newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery)
@ -197,10 +209,11 @@ func (i *timezColumnImpl) from(subQuery ExpressionTable) projection {
return newTimezColumn
}
func (i *timezColumnImpl) From(subQuery ExpressionTable) ColumnTimez {
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return i.from(subQuery).(ColumnTimez)
}
// TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{}
timezColumn.timezInterfaceImpl.parent = timezColumn
@ -211,11 +224,12 @@ func TimezColumn(name string) ColumnTimez {
//------------------------------------------------------//
// ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface {
TimestampExpression
column
From(subQuery ExpressionTable) ColumnTimestamp
From(subQuery SelectTable) ColumnTimestamp
}
type timestampColumnImpl struct {
@ -224,7 +238,7 @@ type timestampColumnImpl struct {
columnImpl
}
func (i *timestampColumnImpl) from(subQuery ExpressionTable) projection {
func (i *timestampColumnImpl) from(subQuery SelectTable) projection {
newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery)
@ -232,10 +246,11 @@ func (i *timestampColumnImpl) from(subQuery ExpressionTable) projection {
return newTimestampColumn
}
func (i *timestampColumnImpl) From(subQuery ExpressionTable) ColumnTimestamp {
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return i.from(subQuery).(ColumnTimestamp)
}
// TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{}
timestampColumn.timestampInterfaceImpl.parent = timestampColumn
@ -246,11 +261,12 @@ func TimestampColumn(name string) ColumnTimestamp {
//------------------------------------------------------//
// ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface {
TimestampzExpression
column
From(subQuery ExpressionTable) ColumnTimestampz
From(subQuery SelectTable) ColumnTimestampz
}
type timestampzColumnImpl struct {
@ -259,7 +275,7 @@ type timestampzColumnImpl struct {
columnImpl
}
func (i *timestampzColumnImpl) from(subQuery ExpressionTable) projection {
func (i *timestampzColumnImpl) from(subQuery SelectTable) projection {
newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery)
@ -267,10 +283,11 @@ func (i *timestampzColumnImpl) from(subQuery ExpressionTable) projection {
return newTimestampzColumn
}
func (i *timestampzColumnImpl) From(subQuery ExpressionTable) ColumnTimestampz {
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return i.from(subQuery).(ColumnTimestampz)
}
// TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{}
timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn
@ -281,11 +298,12 @@ func TimestampzColumn(name string) ColumnTimestampz {
//------------------------------------------------------//
// ColumnDate is interface of SQL date columns.
type ColumnDate interface {
DateExpression
column
From(subQuery ExpressionTable) ColumnDate
From(subQuery SelectTable) ColumnDate
}
type dateColumnImpl struct {
@ -294,7 +312,7 @@ type dateColumnImpl struct {
columnImpl
}
func (i *dateColumnImpl) from(subQuery ExpressionTable) projection {
func (i *dateColumnImpl) from(subQuery SelectTable) projection {
newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery)
@ -302,10 +320,11 @@ func (i *dateColumnImpl) from(subQuery ExpressionTable) projection {
return newDateColumn
}
func (i *dateColumnImpl) From(subQuery ExpressionTable) ColumnDate {
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return i.from(subQuery).(ColumnDate)
}
// DateColumn creates named date column.
func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{}
dateColumn.dateInterfaceImpl.parent = dateColumn

View file

@ -1,5 +1,6 @@
package jet
// DateExpression is interface for all SQL date expressions.
type DateExpression interface {
Expression
@ -63,6 +64,9 @@ func newDateExpressionWrap(expression Expression) DateExpression {
return &dateExpressionWrap
}
// DateExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as date expression.
// Does not add sql cast to generated sql builder output.
func DateExp(expression Expression) DateExpression {
return newDateExpressionWrap(expression)
}

45
date_expression_test.go Normal file
View file

@ -0,0 +1,45 @@
package jet
import "testing"
var dateVar = Date(2000, 12, 30)
func TestDateExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.EQ(table2ColDate), "(table1.col_date = table2.col_date)")
assertClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::date)", "2000-12-30")
}
func TestDateExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.NOT_EQ(table2ColDate), "(table1.col_date != table2.col_date)")
assertClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::date)", "2000-12-30")
}
func TestDateExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(table2ColDate), "(table1.col_date IS DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(table2ColDate), "(table1.col_date IS NOT DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT(table2ColDate), "(table1.col_date > table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::date)", "2000-12-30")
}
func TestDateExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT_EQ(table2ColDate), "(table1.col_date >= table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::date)", "2000-12-30")
}
func TestDateExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT(table2ColDate), "(table1.col_date < table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::date)", "2000-12-30")
}
func TestDateExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT_EQ(table2ColDate), "(table1.col_date <= table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::date)", "2000-12-30")
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/execution"
)
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface {
Statement
@ -48,7 +49,7 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: nil tableName")
}
if err := d.table.serialize(delete_statement, out); err != nil {
if err := d.table.serialize(deleteStatement, out); err != nil {
return err
}
@ -56,11 +57,11 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: deleting without a WHERE clause")
}
if err := out.writeWhere(delete_statement, d.where); err != nil {
if err := out.writeWhere(deleteStatement, d.where); err != nil {
return err
}
if err := out.writeReturning(delete_statement, d.returning); err != nil {
if err := out.writeReturning(deleteStatement, d.returning); err != nil {
return err
}
@ -88,14 +89,14 @@ func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) er
return query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(d, db, context, destination)
func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, d, db, destination)
}
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(d, db, context)
func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, d, db)
}

4
doc.go
View file

@ -1,5 +1,5 @@
/*
Package Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability
to easily convert database query result to desired arbitrary structure.
Package jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability
to easily convert database query result to desired arbitrary structure.
*/
package jet

View file

@ -6,6 +6,7 @@ type enumValue struct {
name string
}
// NewEnumValue creates new named enum value
func NewEnumValue(name string) StringExpression {
enumValue := &enumValue{name: name}

View file

@ -0,0 +1,12 @@
# Quick start example
This package contains sample usage for Jet framework.
Jet generated files of interest are in ./gen folder.
quick-start.go contains code explained at [README.md](../../README.md#quick-start),
with difference of redirecting json output to files(dest.json and dest2.json) rather then to a
standard output.
./gen, dest.json and dest2.json - added to git for presentation purposes.

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
_ "github.com/lib/pq"
"io/ioutil"
// dot import so go code would resemble as much as native SQL
// dot import is not mandatory
@ -12,15 +13,26 @@ import (
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
"github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model"
"github.com/go-jet/jet/tests/dbconfig"
)
const (
Host = "localhost"
Port = 5432
User = "jet"
Password = "jet"
DBName = "jetdb"
)
func main() {
db, err := sql.Open("postgres", dbconfig.ConnectString)
// Connect to database
var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)
db, err := sql.Open("postgres", connectString)
panicOnError(err)
defer db.Close()
// Write query
stmt := SELECT(
Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate,
Film.AllColumns,
@ -42,24 +54,13 @@ func main() {
Film.FilmID.ASC(),
)
query, args, err := stmt.Sql()
panicOnError(err)
fmt.Println("Parameterized query: ")
fmt.Println(query)
fmt.Println("Arguments: ")
fmt.Println(args)
debugSql, err := stmt.DebugSql()
panicOnError(err)
fmt.Println("Debug sql: ")
fmt.Println(debugSql)
// Execute query and store result
var dest []struct {
model.Actor
Films []struct {
model.Film
Language model.Language
Categories []model.Category
}
@ -68,9 +69,8 @@ func main() {
err = stmt.Query(db, &dest)
panicOnError(err)
fmt.Println("dest to json: ")
jsonText, _ := json.MarshalIndent(dest, "", "\t")
fmt.Println(string(jsonText))
printStatementInfo(stmt)
jsonSave("./dest.json", dest)
// New Destination
@ -84,9 +84,35 @@ func main() {
err = stmt.Query(db, &dest2)
panicOnError(err)
fmt.Println("dest2 to json: ")
jsonText, _ = json.MarshalIndent(dest2, "", "\t")
fmt.Println(string(jsonText))
jsonSave("./dest2.json", dest2)
}
func jsonSave(path string, v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
err := ioutil.WriteFile(path, jsonText, 0644)
if err != nil {
panic(err)
}
}
func printStatementInfo(stmt Statement) {
query, args, err := stmt.Sql()
panicOnError(err)
fmt.Println("Parameterized query: ")
fmt.Println(query)
fmt.Println("Arguments: ")
fmt.Println(args)
debugSQL, err := stmt.DebugSql()
panicOnError(err)
fmt.Println("\n\n==============================")
fmt.Println("\n\nDebug sql: ")
fmt.Println(debugSQL)
}
func panicOnError(err error) {

View file

@ -5,6 +5,7 @@ import (
"database/sql"
)
// DB is common database interface used by jet execution
type DB interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

View file

@ -7,16 +7,19 @@ import (
"errors"
"fmt"
"github.com/go-jet/jet/execution/internal"
"github.com/go-jet/jet/internal/utils"
"reflect"
"strconv"
"strings"
"time"
)
func Query(db DB, context context.Context, query string, args []interface{}, destinationPtr interface{}) error {
// Query executes query with arguments over database connection with context and stores result into destination.
// Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
if destinationPtr == nil {
return errors.New("jet: Destination is nil.")
if utils.IsNil(destinationPtr) {
return errors.New("jet: Destination is nil")
}
destinationPtrType := reflect.TypeOf(destinationPtr)
@ -25,12 +28,12 @@ func Query(db DB, context context.Context, query string, args []interface{}, des
}
if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(db, context, query, args, destinationPtr)
return queryToSlice(context, db, query, args, destinationPtr)
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface())
err := queryToSlice(context, db, query, args, tempSlicePtrValue.Interface())
if err != nil {
return err
@ -52,7 +55,7 @@ func Query(db DB, context context.Context, query string, args []interface{}, des
}
}
func queryToSlice(db DB, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error {
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error {
if db == nil {
return errors.New("jet: db is nil")
}
@ -142,23 +145,23 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
} else {
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
}
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
if err != nil {
return
}
if updated {
scanContext.uniqueDestObjectsMap[groupKey] = slicePtrValue.Elem().Len()
err = appendElemToSlice(slicePtrValue, destinationStructPtr)
if err != nil {
return
}
if updated {
scanContext.uniqueDestObjectsMap[groupKey] = slicePtrValue.Elem().Len()
err = appendElemToSlice(slicePtrValue, destinationStructPtr)
if err != nil {
return
}
}
}
return
@ -481,9 +484,8 @@ func valueToString(value reflect.Value) string {
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return "nil"
} else {
valueInterface = value.Elem().Interface()
}
valueInterface = value.Elem().Interface()
} else {
valueInterface = value.Interface()
}
@ -654,14 +656,14 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.
}
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo)
} else {
groupKeyInfo := s.getGroupKeyInfo(structType, structField)
s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo)
}
groupKeyInfo := s.getGroupKeyInfo(structType, structField)
s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo)
}
func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
@ -742,16 +744,6 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
return index
}
func (s *scanContext) getCellValue(typeName, fieldName string) interface{} {
index := s.typeToColumnIndex(typeName, fieldName)
if index < 0 {
return nil
}
return s.rowElem(index)
}
func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer)

View file

@ -5,7 +5,7 @@ import (
"time"
)
// NullByteArray
// NullByteArray struct
type NullByteArray struct {
ByteArray []byte
Valid bool
@ -31,7 +31,7 @@ func (nb NullByteArray) Value() (driver.Value, error) {
return nb.ByteArray, nil
}
//NullTime
// NullTime struct
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
@ -51,6 +51,7 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil
}
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool // Valid is true if Int64 is not NULL
@ -83,6 +84,7 @@ func (n NullInt32) Value() (driver.Value, error) {
return n.Int32, nil
}
// NullInt16 struct
type NullInt16 struct {
Int16 int16
Valid bool // Valid is true if Int64 is not NULL
@ -115,6 +117,7 @@ func (n NullInt16) Value() (driver.Value, error) {
return n.Int16, nil
}
// NullFloat32 struct
type NullFloat32 struct {
Float32 float32
Valid bool // Valid is true if Int64 is not NULL

View file

@ -4,12 +4,13 @@ import (
"errors"
)
// Common expression interface
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
clause
projection
groupByClause
OrderByClause
orderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
@ -25,16 +26,16 @@ type Expression interface {
AS(alias string) projection
// Expression will be used to sort query result in ascending order
ASC() OrderByClause
ASC() orderByClause
// Expression will be used to sort query result in ascending order
DESC() OrderByClause
DESC() orderByClause
}
type expressionInterfaceImpl struct {
parent Expression
}
func (e *expressionInterfaceImpl) from(subQuery ExpressionTable) projection {
func (e *expressionInterfaceImpl) from(subQuery SelectTable) projection {
return e.parent
}
@ -58,11 +59,11 @@ func (e *expressionInterfaceImpl) AS(alias string) projection {
return newAlias(e.parent, alias)
}
func (e *expressionInterfaceImpl) ASC() OrderByClause {
func (e *expressionInterfaceImpl) ASC() orderByClause {
return newOrderByClause(e.parent, true)
}
func (e *expressionInterfaceImpl) DESC() OrderByClause {
func (e *expressionInterfaceImpl) DESC() orderByClause {
return newOrderByClause(e.parent, false)
}
@ -145,13 +146,13 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Prefix Expression is nil.")
return errors.New("jet: Prefix Expression is nil")
}
out.writeString(p.operator + " ")
if p.expression == nil {
return errors.New("jet: nil prefix Expression.")
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
@ -177,11 +178,11 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Postifx operator Expression is nil.")
return errors.New("jet: Postifx operator Expression is nil")
}
if p.expression == nil {
return errors.New("jet: nil prefix Expression.")
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err

View file

@ -1,62 +0,0 @@
package jet
import "errors"
type ExpressionTable interface {
ReadableTable
Alias() string
AllColumns() ProjectionList
}
type expressionTableImpl struct {
readableTableInterfaceImpl
expression Expression
alias string
projections []projection
}
func newExpressionTable(expression Expression, alias string, projections []projection) ExpressionTable {
expTable := &expressionTableImpl{expression: expression, alias: alias}
expTable.readableTableInterfaceImpl.parent = expTable
for _, projection := range projections {
newProjection := projection.from(expTable)
expTable.projections = append(expTable.projections, newProjection)
}
return expTable
}
func (e *expressionTableImpl) Alias() string {
return e.alias
}
func (e *expressionTableImpl) columns() []column {
return nil
}
func (e *expressionTableImpl) AllColumns() ProjectionList {
return e.projections
}
func (e *expressionTableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if e == nil {
return errors.New("jet: Expression table is nil. ")
}
err := e.expression.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeIdentifier(e.alias)
return nil
}

View file

@ -1,5 +1,6 @@
package jet
//FloatExpression is interface for SQL float columns
type FloatExpression interface {
Expression
numericExpression
@ -115,6 +116,9 @@ func newFloatExpressionWrap(expression Expression) FloatExpression {
return &floatExpressionWrap
}
// FloatExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as float expression.
// Does not add sql cast to generated sql builder output.
func FloatExp(expression Expression) FloatExpression {
return newFloatExpressionWrap(expression)
}

View file

@ -14,6 +14,16 @@ func TestFloatExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.NOT_EQ(Float(2.11)), "(table1.col_float != $1)", float64(2.11))
}
func TestFloatExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.IS_DISTINCT_FROM(table2ColFloat), "(table1.col_float IS DISTINCT FROM table2.col_float)")
assertClauseSerialize(t, table1ColFloat.IS_DISTINCT_FROM(Float(2.11)), "(table1.col_float IS DISTINCT FROM $1)", float64(2.11))
}
func TestFloatExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.IS_NOT_DISTINCT_FROM(table2ColFloat), "(table1.col_float IS NOT DISTINCT FROM table2.col_float)")
assertClauseSerialize(t, table1ColFloat.IS_NOT_DISTINCT_FROM(Float(2.11)), "(table1.col_float IS NOT DISTINCT FROM $1)", float64(2.11))
}
func TestFloatExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.GT(table2ColFloat), "(table1.col_float > table2.col_float)")
assertClauseSerialize(t, table1ColFloat.GT(Float(2.11)), "(table1.col_float > $1)", float64(2.11))

View file

@ -2,6 +2,466 @@ package jet
import "errors"
// ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression)
}
// ABSi calculates absolute value from int expression
func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression)
}
// SQRT calculates square root of numeric expression
func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression)
}
// CBRT calculates cube root of numeric expression
func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression)
}
// CEIL calculates ceil of float expression
func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression)
}
// FLOOR calculates floor of float expression
func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression)
}
// ROUND calculates round of a float expressions with optional precision
func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 {
return newFloatFunc("ROUND", floatExpression, precision[0])
}
return newFloatFunc("ROUND", floatExpression)
}
// SIGN returns sign of float expression
func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression)
}
// TRUNC calculates trunc of float expression with optional precision
func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 {
return newFloatFunc("TRUNC", floatExpression, precision[0])
}
return newFloatFunc("TRUNC", floatExpression)
}
// LN calculates natural algorithm of float expression
func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression)
}
// LOG calculates logarithm of float expression
func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression)
}
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression)
}
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
func BIT_AND(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_AND", integerExpression)
}
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
func BIT_OR(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_OR", integerExpression)
}
// BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false
func BOOL_AND(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_AND", boolExpression)
}
// BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false
func BOOL_OR(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_OR", boolExpression)
}
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
func COUNT(expression Expression) IntegerExpression {
return newIntegerFunc("COUNT", expression)
}
// EVERY is aggregate function. Returns true if all input values are true, otherwise false
func EVERY(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("EVERY", boolExpression)
}
// MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression)
}
// MAXi is aggregate function. Returns maximum value of int expression across all input values
func MAXi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MAX", integerExpression)
}
// MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression)
}
// MINi is aggregate function. Returns minimum value of int expression across all input values
func MINi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MIN", integerExpression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression)
}
// SUMi is aggregate function. Returns sum of expression across all integer expression.
func SUMi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("SUM", integerExpression)
}
//------------ String functions ------------------//
// BIT_LENGTH returns number of bits in string expression
func BIT_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("BIT_LENGTH", stringExpression)
}
// CHAR_LENGTH returns number of characters in string expression
func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("CHAR_LENGTH", stringExpression)
}
// OCTET_LENGTH returns number of bytes in string expression
func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("OCTET_LENGTH", stringExpression)
}
// LOWER returns string expression in lower case
func LOWER(stringExpression StringExpression) StringExpression {
return newStringFunc("LOWER", stringExpression)
}
// UPPER returns string expression in upper case
func UPPER(stringExpression StringExpression) StringExpression {
return newStringFunc("UPPER", stringExpression)
}
// BTRIM removes the longest string consisting only of characters
// in characters (a space by default) from the start and end of string
func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("BTRIM", stringExpression, trimChars[0])
}
return newStringFunc("BTRIM", stringExpression)
}
// LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("LTRIM", str, trimChars[0])
}
return newStringFunc("LTRIM", str)
}
// RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("RTRIM", str, trimChars[0])
}
return newStringFunc("RTRIM", str)
}
// CHR returns character with the given code.
func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression)
}
//
//func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...)
//}
//
//func CONCAT_WS(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT_WS", expressions...)
//}
// CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding.
func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT", str, srcEncoding, destEncoding)
}
// CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding.
func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_FROM", str, srcEncoding)
}
// CONVERT_TO converts string to dest_encoding.
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_TO", str, toEncoding)
}
// ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
func ENCODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("ENCODE", data, format)
}
// DECODE decodes binary data from textual representation in string.
// Options for format are same as in encode.
func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format)
}
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression {
// args := []expressions{formatStr}
// args = append(args, formatArgs...)
// return newStringFunc("FORMAT", args...)
//}
// INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric
// characters separated by non-alphanumeric characters.
func INITCAP(str StringExpression) StringExpression {
return newStringFunc("INITCAP", str)
}
// LEFT returns first n characters in the string.
// When n is negative, return all but last |n| characters.
func LEFT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("LEFT", str, n)
}
// RIGHT returns last n characters in the string.
// When n is negative, return all but first |n| characters.
func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("RIGHT", str, n)
}
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("LENGTH", str, encoding[0])
}
return newStringFunc("LENGTH", str)
}
// LPAD fills up the string to length length by prepending the characters
// fill (a space by default). If the string is already longer than length
// then it is truncated (on the right).
func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("LPAD", str, length, text[0])
}
return newStringFunc("LPAD", str, length)
}
// RPAD fills up the string to length length by appending the characters
// fill (a space by default). If the string is already longer than length then it is truncated.
func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("RPAD", str, length, text[0])
}
return newStringFunc("RPAD", str, length)
}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal
func MD5(stringExpression StringExpression) StringExpression {
return newStringFunc("MD5", stringExpression)
}
// REPEAT repeats string the specified number of times
func REPEAT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("REPEAT", str, n)
}
// REPLACE replaces all occurrences in string of substring from with substring to
func REPLACE(text, from, to StringExpression) StringExpression {
return newStringFunc("REPLACE", text, from, to)
}
// REVERSE returns reversed string.
func REVERSE(stringExpression StringExpression) StringExpression {
return newStringFunc("REVERSE", stringExpression)
}
// STRPOS returns location of specified substring (same as position(substring in string),
// but note the reversed argument order)
func STRPOS(str, substring StringExpression) IntegerExpression {
return newIntegerFunc("STRPOS", str, substring)
}
// SUBSTR extracts substring
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 {
return newStringFunc("SUBSTR", str, from, count[0])
}
return newStringFunc("SUBSTR", str, from)
}
// TO_ASCII convert string to ASCII from another encoding
func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("TO_ASCII", str, encoding[0])
}
return newStringFunc("TO_ASCII", str)
}
// TO_HEX converts number to its equivalent hexadecimal representation
func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number)
}
//----------Data Type Formatting Functions ----------------------//
// TO_CHAR converts expression to string with format
func TO_CHAR(expression Expression, format StringExpression) StringExpression {
return newStringFunc("TO_CHAR", expression, format)
}
// TO_DATE converts string to date using format
func TO_DATE(dateStr, format StringExpression) DateExpression {
return newDateFunc("TO_DATE", dateStr, format)
}
// TO_NUMBER converts string to numeric using format
func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format)
}
// TO_TIMESTAMP converts string to time stamp with time zone using format
func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression {
return newTimestampzFunc("TO_TIMESTAMP", timestampzStr, format)
}
//----------------- Date/Time Functions and Operators ---------------//
// CURRENT_DATE returns current date
func CURRENT_DATE() DateExpression {
dateFunc := newDateFunc("CURRENT_DATE")
dateFunc.noBrackets = true
return dateFunc
}
// CURRENT_TIME returns current time with time zone
func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
timezFunc.noBrackets = true
return timezFunc
}
// CURRENT_TIMESTAMP returns current timestamp with time zone
func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
timestampzFunc.noBrackets = true
return timestampzFunc
}
// LOCALTIME returns local time of day using optional precision
func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
timeFunc.noBrackets = true
return timeFunc
}
// LOCALTIMESTAMP returns current date and time using optional precision
func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0]))
} else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP")
}
timestampFunc.noBrackets = true
return timestampFunc
}
// NOW returns current date and time
func NOW() TimestampzExpression {
return newTimestampzFunc("NOW")
}
// --------------- Conditional Expressions Functions -------------//
// COALESCE function returns the first of its arguments that is not null.
func COALESCE(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("COALESCE", allValues, nil)
}
// NULLIF function returns a null value if value1 equals value2; otherwise it returns value1.
func NULLIF(value1, value2 Expression) Expression {
return newFunc("NULLIF", []Expression{value1, value2}, nil)
}
// GREATEST selects the largest value from a list of expressions
func GREATEST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("GREATEST", allValues, nil)
}
// LEAST selects the smallest value from a list of expressions
func LEAST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("LEAST", allValues, nil)
}
//--------------------------------------------------------------------//
type funcExpressionImpl struct {
expressionInterfaceImpl
@ -175,375 +635,3 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
return timestampzFunc
}
func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression)
}
func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression)
}
func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression)
}
func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression)
}
func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression)
}
func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression)
}
func ROUND(floatExpression FloatExpression, intExpression ...IntegerExpression) FloatExpression {
if len(intExpression) > 0 {
return newFloatFunc("ROUND", floatExpression, intExpression[0])
}
return newFloatFunc("ROUND", floatExpression)
}
func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression)
}
func TRUNC(floatExpression FloatExpression, intExpression ...IntegerExpression) FloatExpression {
if len(intExpression) > 0 {
return newFloatFunc("TRUNC", floatExpression, intExpression[0])
}
return newFloatFunc("TRUNC", floatExpression)
}
func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression)
}
func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression)
}
// ----------------- Aggregate functions -------------------//
func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression)
}
func BIT_AND(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_AND", integerExpression)
}
func BIT_OR(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("BIT_OR", integerExpression)
}
func BOOL_AND(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_AND", boolExpression)
}
func BOOL_OR(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("BOOL_OR", boolExpression)
}
func COUNT(expression Expression) IntegerExpression {
return newIntegerFunc("COUNT", expression)
}
func EVERY(boolExpression BoolExpression) BoolExpression {
return newBoolFunc("EVERY", boolExpression)
}
func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression)
}
func MAXi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MAX", integerExpression)
}
func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression)
}
func MINi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("MIN", integerExpression)
}
func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression)
}
func SUMi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("SUM", integerExpression)
}
//------------ String functions ------------------//
func BIT_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("BIT_LENGTH", stringExpression)
}
func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("CHAR_LENGTH", stringExpression)
}
func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
return newIntegerFunc("OCTET_LENGTH", stringExpression)
}
func LOWER(stringExpression StringExpression) StringExpression {
return newStringFunc("LOWER", stringExpression)
}
func UPPER(stringExpression StringExpression) StringExpression {
return newStringFunc("UPPER", stringExpression)
}
func BTRIM(stringExpression StringExpression) StringExpression {
return newStringFunc("BTRIM", stringExpression)
}
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("LTRIM", str, trimChars[0])
}
return newStringFunc("LTRIM", str)
}
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 {
return newStringFunc("RTRIM", str, trimChars[0])
}
return newStringFunc("RTRIM", str)
}
func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression)
}
//
//func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...)
//}
//
//func CONCAT_WS(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT_WS", expressions...)
//}
func CONVERT(str StringExpression, fromEncoding StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT", str, fromEncoding, toEncoding)
}
func CONVERT_FROM(str StringExpression, fromEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_FROM", str, fromEncoding)
}
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_TO", str, toEncoding)
}
func ENCODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("ENCODE", data, format)
}
func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format)
}
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression {
// args := []expressions{formatStr}
// args = append(args, formatArgs...)
// return newStringFunc("FORMAT", args...)
//}
func INITCAP(str StringExpression) StringExpression {
return newStringFunc("INITCAP", str)
}
func LEFT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("LEFT", str, n)
}
func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("RIGHT", str, n)
}
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("LENGTH", str, encoding[0])
}
return newStringFunc("LENGTH", str)
}
func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("LPAD", str, length, text[0])
}
return newStringFunc("LPAD", str, length)
}
func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 {
return newStringFunc("RPAD", str, length, text[0])
}
return newStringFunc("RPAD", str, length)
}
func MD5(stringExpression StringExpression) StringExpression {
return newStringFunc("MD5", stringExpression)
}
func REPEAT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("REPEAT", str, n)
}
func REPLACE(text, from, to StringExpression) StringExpression {
return newStringFunc("REPLACE", text, from, to)
}
func REVERSE(stringExpression StringExpression) StringExpression {
return newStringFunc("REVERSE", stringExpression)
}
func STRPOS(str, substring StringExpression) IntegerExpression {
return newIntegerFunc("STRPOS", str, substring)
}
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 {
return newStringFunc("SUBSTR", str, from, count[0])
}
return newStringFunc("SUBSTR", str, from)
}
func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 {
return newStringFunc("TO_ASCII", str, encoding[0])
}
return newStringFunc("TO_ASCII", str)
}
func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number)
}
//----------Data Type Formatting Functions ----------------------//
func TO_CHAR(expression Expression, text StringExpression) StringExpression {
return newStringFunc("TO_CHAR", expression, text)
}
func TO_DATE(dateStr, format StringExpression) DateExpression {
return newDateFunc("TO_DATE", dateStr, format)
}
func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format)
}
func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression {
return newTimestampzFunc("TO_TIMESTAMP", timestampzStr, format)
}
//----------------- Date/Time Functions and Operators ---------------//
func CURRENT_DATE() DateExpression {
dateFunc := newDateFunc("CURRENT_DATE")
dateFunc.noBrackets = true
return dateFunc
}
func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
timezFunc.noBrackets = true
return timezFunc
}
func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
timestampzFunc.noBrackets = true
return timestampzFunc
}
func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
timeFunc.noBrackets = true
return timeFunc
}
func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0]))
} else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP")
}
timestampFunc.noBrackets = true
return timestampFunc
}
func NOW() TimestampzExpression {
return newTimestampzFunc("NOW")
}
// --------------- Conditional Expressions Functions -------------//
func COALESCE(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("COALESCE", allValues, nil)
}
func NULLIF(value1, value2 Expression) Expression {
return newFunc("NULLIF", []Expression{value1, value2}, nil)
}
func GREATEST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("GREATEST", allValues, nil)
}
func LEAST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value}
allValues = append(allValues, values...)
return newFunc("LEAST", allValues, nil)
}

View file

@ -1,7 +1,6 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
@ -157,13 +156,7 @@ func TestFuncLEAST(t *testing.T) {
assertClauseSerialize(t, LEAST(Float(11.2222), NULL, String("str")), "LEAST($1, NULL, $2)", float64(11.2222), "str")
}
func TestInterval(t *testing.T) {
query := INTERVAL(`6 years 5 months 4 days 3 hours 2 minutes 1 second`)
queryData := &sqlBuilder{}
err := query.serialize(select_statement, queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `INTERVAL $1`)
func TestTO_ASCII(t *testing.T) {
assertClauseSerialize(t, TO_ASCII(String("Karel")), `TO_ASCII($1)`, "Karel")
assertClauseSerialize(t, TO_ASCII(String("Karel")), `TO_ASCII($1)`, "Karel")
}

View file

@ -1,5 +1,6 @@
package metadata
// MetaData interface
type MetaData interface {
Name() string
}

View file

@ -1,4 +1,4 @@
package postgres_metadata
package postgresmeta
import (
"database/sql"
@ -7,6 +7,7 @@ import (
"strings"
)
// ColumnInfo metadata struct
type ColumnInfo struct {
Name string
IsNullable bool
@ -14,6 +15,7 @@ type ColumnInfo struct {
EnumName string
}
// SqlBuilderColumnType returns type of jet sql builder column
func (c ColumnInfo) SqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
@ -41,6 +43,7 @@ func (c ColumnInfo) SqlBuilderColumnType() string {
}
}
// GoBaseType returns model type for column info.
func (c ColumnInfo) GoBaseType() string {
switch c.DataType {
case "USER-DEFINED":
@ -72,6 +75,8 @@ func (c ColumnInfo) GoBaseType() string {
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
@ -81,6 +86,7 @@ func (c ColumnInfo) GoModelType() string {
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}

View file

@ -1,15 +1,17 @@
package postgres_metadata
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// EnumInfo struct
type EnumInfo struct {
name string
Values []string
}
// Name returns enum name
func (e EnumInfo) Name() string {
return e.name
}

View file

@ -1,10 +1,11 @@
package postgres_metadata
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// SchemaInfo metadata struct
type SchemaInfo struct {
DatabaseName string
Name string
@ -12,6 +13,7 @@ type SchemaInfo struct {
EnumInfos []metadata.MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
schemaInfo.DatabaseName = databaseName

View file

@ -1,10 +1,11 @@
package postgres_metadata
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/internal/utils"
)
// TableInfo metadata struct
type TableInfo struct {
SchemaName string
name string
@ -12,14 +13,17 @@ type TableInfo struct {
Columns []ColumnInfo
}
// Name returns table info name
func (t TableInfo) Name() string {
return t.name
}
func (t TableInfo) IsPrimaryKey(columnName string) bool {
return t.PrimaryKeys[columnName]
// IsPrimaryKey returns if column is a part of primary key
func (t TableInfo) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column]
}
// MutableColumns returns list of mutable columns for table
func (t TableInfo) MutableColumns() []ColumnInfo {
ret := []ColumnInfo{}
@ -34,6 +38,7 @@ func (t TableInfo) MutableColumns() []ColumnInfo {
return ret
}
// GetImports returns model imports for table.
func (t TableInfo) GetImports() []string {
imports := map[string]string{}
@ -57,10 +62,12 @@ func (t TableInfo) GetImports() []string {
return ret
}
// GoStructName returns go struct name for sql builder
func (t TableInfo) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table"
}
// GetTableInfo returns table info metadata
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) {
tableInfo.SchemaName = schemaName

View file

@ -4,16 +4,17 @@ import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/metadata/postgres-metadata"
"github.com/go-jet/jet/generator/internal/metadata/postgresmeta"
"github.com/go-jet/jet/internal/utils"
_ "github.com/lib/pq"
"path"
"path/filepath"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port string
Port int
User string
Password string
SslMode string
@ -23,10 +24,11 @@ type DBConnection struct {
SchemaName string
}
func Generate(destDir string, genData DBConnection) error {
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
genData.Host, genData.Port, genData.User, genData.Password, genData.DBName, genData.SslMode, genData.Params)
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
@ -43,7 +45,7 @@ func Generate(destDir string, genData DBConnection) error {
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DBName, genData.SchemaName)
schemaInfo, err := postgresmeta.GetSchemaInfo(db, dbConn.DBName, dbConn.SchemaName)
if err != nil {
return err
@ -55,7 +57,7 @@ func Generate(destDir string, genData DBConnection) error {
return nil
}
schemaGenPath := path.Join(destDir, genData.DBName, genData.SchemaName)
schemaGenPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
fmt.Println("Destination directory:", schemaGenPath)
fmt.Println("Cleaning up destination directory...")
err = utils.CleanUpGeneratedFiles(schemaGenPath)
@ -99,7 +101,7 @@ func Generate(destDir string, genData DBConnection) error {
return nil
}
func generate(schemaInfo postgres_metadata.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
func generate(schemaInfo postgresmeta.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
err := utils.EnsureDirPath(modelDirPath)

View file

@ -5,8 +5,10 @@ import (
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
@ -81,11 +83,11 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.newLine()
queryData.writeString("INSERT INTO")
if isNil(i.table) {
if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil")
}
err = i.table.serialize(insert_statement, queryData)
err = i.table.serialize(insertStatement, queryData)
if err != nil {
return
@ -114,8 +116,8 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
if len(i.rows) > 0 {
queryData.writeString("VALUES")
for row_i, row := range i.rows {
if row_i > 0 {
for rowIndex, row := range i.rows {
if rowIndex > 0 {
queryData.writeString(",")
}
@ -123,7 +125,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.newLine()
queryData.writeString("(")
err = serializeClauseList(insert_statement, row, queryData)
err = serializeClauseList(insertStatement, row, queryData)
if err != nil {
return "", nil, err
@ -135,14 +137,14 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
if i.query != nil {
err = i.query.serialize(insert_statement, queryData)
err = i.query.serialize(insertStatement, queryData)
if err != nil {
return
}
}
if err = queryData.writeReturning(insert_statement, i.returning); err != nil {
if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
return
}
@ -155,14 +157,14 @@ func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) er
return query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(i, db, context, destination)
func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db)
}
func (i *insertStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(i, db, context)
func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, i, db)
}

View file

@ -81,13 +81,13 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(toInsert).
MODEL(&toInsert)
expectedSql := `
expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4);
`
assertStatement(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11))
assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -125,23 +125,23 @@ func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1))
var expectedSql = `
var expectedSQL = `
INSERT INTO db.table1 (col1) (
SELECT table1.col1 AS "table1.col1"
FROM db.table1
);
`
assertStatement(t, stmt, expectedSql)
assertStatement(t, stmt, expectedSQL)
}
func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
var expectedSql = `
var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, $1);
`
assertStatement(t, stmt, expectedSql, "two")
assertStatement(t, stmt, expectedSQL, "two")
}

View file

@ -1,5 +1,6 @@
package jet
// IntegerExpression interface
type IntegerExpression interface {
Expression
numericExpression
@ -180,6 +181,9 @@ func newIntExpressionWrap(expression Expression) IntegerExpression {
return &intExpressionWrap
}
// IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output.
func IntExp(expression Expression) IntegerExpression {
return newIntExpressionWrap(expression)
}

View file

@ -8,42 +8,9 @@ import (
"unicode"
)
// CamelToSnake converts a given string to snake case
func CamelToSnake(s string) string {
var result string
var words []string
var lastPos int
rs := []rune(s)
for i := 0; i < len(rs); i++ {
if i > 0 && unicode.IsUpper(rs[i]) {
if initialism := startsWithInitialism(s[lastPos:]); initialism != "" {
words = append(words, initialism)
i += len(initialism) - 1
lastPos = i
continue
}
words = append(words, s[lastPos:i])
lastPos = i
}
}
// append the last word
if s[lastPos:] != "" {
words = append(words, s[lastPos:])
}
for k, word := range words {
if k > 0 {
result += "_"
}
result += strings.ToLower(word)
}
return result
// SnakeToCamel returns a string converted from snake case to uppercase
func SnakeToCamel(s string) string {
return snakeToCamel(s, true)
}
func snakeToCamel(s string, upperCase bool) string {
@ -54,28 +21,6 @@ func snakeToCamel(s string, upperCase bool) string {
words := strings.Split(s, "_")
//// if there is no underscore, first try commons and then just return
//if len(words) == 1 {
// if exception := snakeToCamelExceptions[words[0]]; len(exception) > 0 {
// return exception
// }
//
// if upperCase {
// if upper := strings.ToUpper(words[0]); commonInitialisms[upper] {
// return upper
// }
// }
//
// w := []rune(s)
// if upperCase {
// w[0] = unicode.ToUpper(w[0])
// } else {
// w[0] = unicode.ToLower(w[0])
// }
//
// return string(w)
//}
for i, word := range words {
if exception := snakeToCamelExceptions[word]; len(exception) > 0 {
result += exception
@ -117,28 +62,6 @@ func camelizeWord(word string, force bool) string {
return string(runes)
}
// SnakeToCamel returns a string converted from snake case to uppercase
func SnakeToCamel(s string) string {
return snakeToCamel(s, true)
}
// SnakeToCamelLower returns a string converted from snake case to lowercase
func SnakeToCamelLower(s string) string {
return snakeToCamel(s, false)
}
// startsWithInitialism returns the initialism if the given string begins with it
func startsWithInitialism(s string) string {
var initialism string
// the longest initialism is 5 char, the shortest 2
for i := 1; i <= 5; i++ {
if len(s) > i-1 && commonInitialisms[s[:i]] {
initialism = s[:i]
}
}
return initialism
}
// commonInitialisms, taken from
// https://github.com/golang/lint/blob/206c0f020eba0f7fbcfbc467a5eb808037df2ed6/lint.go#L731
var commonInitialisms = map[string]bool{

View file

@ -6,56 +6,6 @@ import (
)
var _ = Describe("Snaker", func() {
Describe("CamelToSnake test", func() {
It("should return an empty string on an empty input", func() {
Expect(CamelToSnake("")).To(Equal(""))
})
It("should work with one word", func() {
Expect(CamelToSnake("One")).To(Equal("one"))
})
It("should return an uppercase string as seperate words", func() {
Expect(CamelToSnake("ONE")).To(Equal("o_n_e"))
})
It("should return ID as lowercase", func() {
Expect(CamelToSnake("ID")).To(Equal("id"))
})
It("should work with a single lowercase character", func() {
Expect(CamelToSnake("i")).To(Equal("i"))
})
It("should work with a single uppcase character", func() {
Expect(CamelToSnake("I")).To(Equal("i"))
})
It("should return a long text as expected", func() {
Expect(CamelToSnake("ThisHasToBeConvertedCorrectlyID")).To(
Equal("this_has_to_be_converted_correctly_id"))
})
It("should return the text as expected if the initialism is in the middle", func() {
Expect(CamelToSnake("ThisIDIsFine")).To(Equal("this_id_is_fine"))
})
It("should work with long initialism", func() {
Expect(CamelToSnake("ThisHTTPSConnection")).To(Equal("this_https_connection"))
})
It("should work with multi initialisms", func() {
Expect(CamelToSnake("HelloHTTPSConnectionID")).To(Equal("hello_https_connection_id"))
})
It("sould work with concat initialisms", func() {
Expect(CamelToSnake("HTTPSID")).To(Equal("https_id"))
})
It("sould work with initialism where only certain characters are uppercase", func() {
Expect(CamelToSnake("OAuthClient")).To(Equal("oauth_client"))
})
})
Describe("SnakeToCamel test", func() {
It("should return an empty string on an empty input", func() {
@ -83,39 +33,8 @@ var _ = Describe("Snaker", func() {
Expect(SnakeToCamel("id")).To(Equal("ID"))
})
It("sould work with initialism where only certain characters are uppercase", func() {
It("should work with initialism where only certain characters are uppercase", func() {
Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient"))
})
})
Describe("SnakeToCamelLower test", func() {
It("should return an empty string on an empty input", func() {
Ω(SnakeToCamelLower("")).To(Equal(""))
})
It("should not blow up on trailing _", func() {
Ω(SnakeToCamelLower("potato_")).To(Equal("potato"))
})
It("should return a snaked text as camel case", func() {
Ω(SnakeToCamelLower("this_has_to_be_uppercased")).To(
Equal("thisHasToBeUppercased"))
})
It("should return a snaked text as camel case, except the word ID", func() {
Ω(SnakeToCamelLower("this_is_an_id")).To(Equal("thisIsAnID"))
})
It("should return 'id' not as uppercase", func() {
Ω(SnakeToCamelLower("this_is_an_identifier")).To(Equal("thisIsAnIdentifier"))
})
It("should simply work with id", func() {
Ω(SnakeToCamelLower("id")).To(Equal("id"))
})
It("should simply work with leading id", func() {
Ω(SnakeToCamelLower("id_me_please")).To(Equal("idMePlease"))
})
})
})

View file

@ -6,24 +6,24 @@ import (
"go/format"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"text/template"
"time"
)
// ToGoIdentifier converts database to Go identifier.
func ToGoIdentifier(databaseIdentifier string) string {
if len(databaseIdentifier) == 0 {
return databaseIdentifier
}
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
}
// ToGoFileName converts database identifier to Go file name.
func ToGoFileName(databaseIdentifier string) string {
return strings.ToLower(replaceInvalidChars(databaseIdentifier))
}
// SaveGoFile saves go file at folder dir, with name fileName and contents text.
func SaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath := filepath.Join(dirPath, fileName) + ".go"
@ -49,6 +49,7 @@ func SaveGoFile(dirPath, fileName string, text []byte) error {
return nil
}
// EnsureDirPath ensures dir path exists. If path does not exist, creates new path.
func EnsureDirPath(dirPath string) error {
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err := os.MkdirAll(dirPath, os.ModePerm)
@ -61,6 +62,7 @@ func EnsureDirPath(dirPath string) error {
return nil
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
@ -82,6 +84,7 @@ func GenerateTemplate(templateText string, templateData interface{}) ([]byte, er
return buf.Bytes(), nil
}
// CleanUpGeneratedFiles deletes everything at folder dir.
func CleanUpGeneratedFiles(dir string) error {
exist, err := DirExists(dir)
@ -100,6 +103,7 @@ func CleanUpGeneratedFiles(dir string) error {
return nil
}
// DirExists checks if folder at path exist.
func DirExists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
@ -118,8 +122,7 @@ func replaceInvalidChars(str string) string {
return str
}
// github.com/lib/pq
// FormatTimestamp formats t into Postgres' text format for timestamps.
// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq
func FormatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
@ -152,3 +155,7 @@ func FormatTimestamp(t time.Time) []byte {
}
return b
}
func IsNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
}

View file

@ -6,6 +6,7 @@ import (
)
func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier(""), "")
assert.Equal(t, ToGoIdentifier("uuid"), "UUID")
assert.Equal(t, ToGoIdentifier("col1"), "Col1")
assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13")

View file

@ -1,11 +1,14 @@
package jet
const (
// DEFAULT is jet equivalent of SQL DEFAULT
DEFAULT keywordClause = "DEFAULT"
)
var (
// NULL is jet equivalent of SQL NULL
NULL = newNullLiteral()
// STAR is jet equivalent of SQL *
STAR = newStarLiteral()
)

View file

@ -27,7 +27,7 @@ func (l literalExpression) serialize(statement statementType, out *sqlBuilder, o
if l.constant {
out.insertConstantArgument(l.value)
} else {
out.insertPreparedArgument(l.value)
out.insertParametrizedArgument(l.value)
}
return nil
@ -38,6 +38,7 @@ type integerLiteralExpression struct {
integerInterfaceImpl
}
// Int is constructor for integer expressions literals.
func Int(value int64) IntegerExpression {
numLiteral := &integerLiteralExpression{}
@ -55,6 +56,7 @@ type boolLiteralExpression struct {
literalExpression
}
// Bool creates new bool literal expression
func Bool(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{}
@ -70,6 +72,7 @@ type floatLiteral struct {
literalExpression
}
// Float creates new float literal expression
func Float(value float64) FloatExpression {
floatLiteral := floatLiteral{}
floatLiteral.literalExpression = *literal(value)
@ -85,6 +88,7 @@ type stringLiteral struct {
literalExpression
}
// String creates new string literal expression
func String(value string) StringExpression {
stringLiteral := stringLiteral{}
stringLiteral.literalExpression = *literal(value)
@ -100,6 +104,7 @@ type timeLiteral struct {
literalExpression
}
// Time creates new time literal expression
func Time(hour, minute, second, milliseconds int) TimeExpression {
timeLiteral := &timeLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d", hour, minute, second, milliseconds)
@ -116,6 +121,7 @@ type timezLiteral struct {
literalExpression
}
// Timez creates new time with time zone literal expression
func Timez(hour, minute, second, milliseconds, timezone int) TimezExpression {
timezLiteral := &timezLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d %+03d", hour, minute, second, milliseconds, timezone)
@ -132,6 +138,7 @@ type timestampLiteral struct {
literalExpression
}
// Timestamp creates new timestamp literal expression
func Timestamp(year, month, day, hour, minute, second, milliseconds int) TimestampExpression {
timestampLiteral := &timestampLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d", year, month, day, hour, minute, second, milliseconds)
@ -148,6 +155,7 @@ type timestampzLiteral struct {
literalExpression
}
// Timestampz creates new timestamp with time zone literal expression
func Timestampz(year, month, day, hour, minute, second, milliseconds, timezone int) TimestampzExpression {
timestampzLiteral := &timestampzLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d %+04d",
@ -166,6 +174,7 @@ type dateLiteral struct {
literalExpression
}
//Date creates new date expression
func Date(year, month, day int) DateExpression {
dateLiteral := &dateLiteral{}
@ -226,6 +235,7 @@ func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...se
return err
}
// WRAP wraps list of expressions with brackets '(' and ')'
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.expressionInterfaceImpl.parent = wrap
@ -245,6 +255,8 @@ func (n *rawExpression) serialize(statement statementType, out *sqlBuilder, opti
return nil
}
// RAW can be used for any unsupported functions, operators or expressions.
// For example: RAW("current_database()")
func RAW(raw string) Expression {
rawExp := &rawExpression{raw: raw}
rawExp.expressionInterfaceImpl.parent = rawExp

View file

@ -7,8 +7,10 @@ import (
"github.com/go-jet/jet/execution"
)
// TableLockMode is a type of possible SQL table lock
type TableLockMode string
// Lock types for LockStatement.
const (
LOCK_ACCESS_SHARE = "ACCESS SHARE"
LOCK_ROW_SHARE = "ROW SHARE"
@ -20,6 +22,7 @@ const (
LOCK_ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
)
// LockStatement interface for SQL LOCK statement
type LockStatement interface {
Statement
@ -33,6 +36,7 @@ type lockStatementImpl struct {
nowait bool
}
// LOCK creates lock statement for list of tables.
func LOCK(tables ...WritableTable) LockStatement {
return &lockStatementImpl{
tables: tables,
@ -55,11 +59,11 @@ func (l *lockStatementImpl) DebugSql() (query string, err error) {
func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) {
if l == nil {
return "", nil, errors.New("jet: nil Statement.")
return "", nil, errors.New("jet: nil Statement")
}
if len(l.tables) == 0 {
return "", nil, errors.New("jet: There is no table selected to be locked. ")
return "", nil, errors.New("jet: There is no table selected to be locked")
}
out := &sqlBuilder{}
@ -72,7 +76,7 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
out.writeString(", ")
}
err := table.serialize(lock_statement, out)
err := table.serialize(lockStatement, out)
if err != nil {
return "", nil, err
@ -97,14 +101,14 @@ func (l *lockStatementImpl) Query(db execution.DB, destination interface{}) erro
return query(l, db, destination)
}
func (l *lockStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(l, db, context, destination)
func (l *lockStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, l, db, destination)
}
func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) {
return exec(l, db)
}
func (l *lockStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(l, db, context)
func (l *lockStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, l, db)
}

View file

@ -1,5 +1,6 @@
package jet
// NumericExpression is common interface for all integer and float expressions
type NumericExpression interface {
Expression
numericExpression

View file

@ -4,17 +4,19 @@ import "errors"
//----------- Logical operators ---------------//
// Returns negation of bool expression expr
// NOT returns negation of bool expression result
func NOT(exp BoolExpression) BoolExpression {
return newPrefixBoolOperator(exp, "NOT")
}
// BIT_NOT inverts every bit in integer expression result
func BIT_NOT(expr IntegerExpression) IntegerExpression {
return newPrefixIntegerOperator(expr, "~")
}
//----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS")
}
@ -59,12 +61,13 @@ func gtEq(lhs, rhs Expression) BoolExpression {
// --------------- CASE operator -------------------//
type CaseOperatorExpression interface {
// CaseOperator is interface for SQL case operator
type CaseOperator interface {
Expression
WHEN(condition Expression) CaseOperatorExpression
THEN(then Expression) CaseOperatorExpression
ELSE(els Expression) CaseOperatorExpression
WHEN(condition Expression) CaseOperator
THEN(then Expression) CaseOperator
ELSE(els Expression) CaseOperator
}
type caseOperatorImpl struct {
@ -76,7 +79,8 @@ type caseOperatorImpl struct {
els Expression
}
func CASE(expression ...Expression) CaseOperatorExpression {
// CASE create CASE operator with optional list of expressions
func CASE(expression ...Expression) CaseOperator {
caseExp := &caseOperatorImpl{}
if len(expression) > 0 {
@ -88,17 +92,17 @@ func CASE(expression ...Expression) CaseOperatorExpression {
return caseExp
}
func (c *caseOperatorImpl) WHEN(when Expression) CaseOperatorExpression {
func (c *caseOperatorImpl) WHEN(when Expression) CaseOperator {
c.when = append(c.when, when)
return c
}
func (c *caseOperatorImpl) THEN(then Expression) CaseOperatorExpression {
func (c *caseOperatorImpl) THEN(then Expression) CaseOperator {
c.then = append(c.then, then)
return c
}
func (c *caseOperatorImpl) ELSE(els Expression) CaseOperatorExpression {
func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
c.els = els
return c

View file

@ -5,6 +5,7 @@ import "testing"
func TestOperatorNOT(t *testing.T) {
notExpression := NOT(Int(2).EQ(Int(1)))
assertClauseSerialize(t, NOT(table1ColBool), "NOT table1.col_bool")
assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1))
assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1))
assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5))

View file

@ -2,7 +2,8 @@ package jet
import "errors"
type OrderByClause interface {
// OrderByClause
type orderByClause interface {
serializeForOrderBy(statement statementType, out *sqlBuilder) error
}
@ -13,7 +14,7 @@ type orderByClauseImpl struct {
func (o *orderByClauseImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if o.expression == nil {
return errors.New("jet: nil orderBy by clause.")
return errors.New("jet: nil orderBy by clause")
}
if err := o.expression.serializeForOrderBy(statement, out); err != nil {
@ -29,6 +30,6 @@ func (o *orderByClauseImpl) serializeForOrderBy(statement statementType, out *sq
return nil
}
func newOrderByClause(expression Expression, ascent bool) OrderByClause {
func newOrderByClause(expression Expression, ascent bool) orderByClause {
return &orderByClauseImpl{expression: expression, ascent: ascent}
}

View file

@ -2,12 +2,13 @@ package jet
type projection interface {
serializeForProjection(statement statementType, out *sqlBuilder) error
from(subQuery ExpressionTable) projection
from(subQuery SelectTable) projection
}
// ProjectionList is a redefined type, so that ProjectionList can be used as a projection.
type ProjectionList []projection
func (cl ProjectionList) from(subQuery ExpressionTable) projection {
func (cl ProjectionList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{}
for _, projection := range cl {

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/execution"
)
// Select statements lock types
var (
UPDATE = newLock("UPDATE")
NO_KEY_UPDATE = newLock("NO KEY UPDATE")
@ -14,6 +15,7 @@ var (
KEY_SHARE = newLock("KEY SHARE")
)
// SelectStatement is interface for SQL SELECT statements
type SelectStatement interface {
Statement
Expression
@ -23,7 +25,7 @@ type SelectStatement interface {
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...groupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
ORDER_BY(orderByClauses ...orderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR(lock SelectLock) SelectStatement
@ -35,11 +37,12 @@ type SelectStatement interface {
EXCEPT(rhs SelectStatement) SelectStatement
EXCEPT_ALL(rhs SelectStatement) SelectStatement
AsTable(alias string) ExpressionTable
AsTable(alias string) SelectTable
projections() []projection
}
//SELECT creates new SelectStatement with list of projections
func SELECT(projection1 projection, projections ...projection) SelectStatement {
return newSelectStatement(nil, append([]projection{projection1}, projections...))
}
@ -55,7 +58,7 @@ type selectStatementImpl struct {
groupBy []groupByClause
having BoolExpression
orderBy []OrderByClause
orderBy []orderByClause
limit, offset int64
lockFor SelectLock
@ -81,8 +84,8 @@ func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
return s.parent
}
func (s *selectStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias, s.parent.projections())
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s.parent, alias)
}
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
@ -100,7 +103,7 @@ func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement
return s.parent
}
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
func (s *selectStatementImpl) ORDER_BY(clauses ...orderByClause) SelectStatement {
s.orderBy = clauses
return s.parent
}
@ -189,20 +192,20 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: no column selected for projection")
}
err := out.writeProjections(select_statement, s.projectionList)
err := out.writeProjections(selectStatement, s.projectionList)
if err != nil {
return err
}
if s.table != nil {
if err := out.writeFrom(select_statement, s.table); err != nil {
if err := out.writeFrom(selectStatement, s.table); err != nil {
return err
}
}
if s.where != nil {
err := out.writeWhere(select_statement, s.where)
err := out.writeWhere(selectStatement, s.where)
if err != nil {
return nil
@ -210,7 +213,7 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
}
if s.groupBy != nil && len(s.groupBy) > 0 {
err := out.writeGroupBy(select_statement, s.groupBy)
err := out.writeGroupBy(selectStatement, s.groupBy)
if err != nil {
return err
@ -218,7 +221,7 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
}
if s.having != nil {
err := out.writeHaving(select_statement, s.having)
err := out.writeHaving(selectStatement, s.having)
if err != nil {
return err
@ -226,7 +229,7 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
}
if s.orderBy != nil {
err := out.writeOrderBy(select_statement, s.orderBy)
err := out.writeOrderBy(selectStatement, s.orderBy)
if err != nil {
return err
@ -236,19 +239,19 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
if s.limit >= 0 {
out.newLine()
out.writeString("LIMIT")
out.insertPreparedArgument(s.limit)
out.insertParametrizedArgument(s.limit)
}
if s.offset >= 0 {
out.newLine()
out.writeString("OFFSET")
out.insertPreparedArgument(s.offset)
out.insertParametrizedArgument(s.offset)
}
if s.lockFor != nil {
out.newLine()
out.writeString("FOR")
err := s.lockFor.serialize(select_statement, out)
err := s.lockFor.serialize(selectStatement, out)
if err != nil {
return err
@ -280,20 +283,19 @@ func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) er
return query(s.parent, db, destination)
}
func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(s.parent, db, context, destination)
func (s *selectStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, s.parent, db, destination)
}
func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s.parent, db)
}
func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(s.parent, db, context)
func (s *selectStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, s.parent, db)
}
// SelectLock
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
clause

View file

@ -122,3 +122,91 @@ FROM db.table1
FOR NO KEY UPDATE SKIP LOCKED;
`)
}
func TestSelectSets(t *testing.T) {
select1 := SELECT(table1ColBool).FROM(table1)
select2 := SELECT(table2ColBool).FROM(table2)
assertStatement(t, select1.UNION(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.UNION_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
UNION ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.INTERSECT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
assertStatement(t, select1.EXCEPT_ALL(select2), `
(
(
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2
)
);
`)
}

63
select_table.go Normal file
View file

@ -0,0 +1,63 @@
package jet
import "errors"
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
ReadableTable
Alias() string
AllColumns() ProjectionList
}
type selectTableImpl struct {
readableTableInterfaceImpl
selectStmt SelectStatement
alias string
projections []projection
}
func newSelectTable(selectStmt SelectStatement, alias string) SelectTable {
expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
expTable.readableTableInterfaceImpl.parent = expTable
for _, projection := range selectStmt.projections() {
newProjection := projection.from(expTable)
expTable.projections = append(expTable.projections, newProjection)
}
return expTable
}
func (s *selectTableImpl) Alias() string {
return s.alias
}
func (s *selectTableImpl) columns() []column {
return nil
}
func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *selectTableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if s == nil {
return errors.New("jet: Expression table is nil. ")
}
err := s.selectStmt.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeIdentifier(s.alias)
return nil
}

View file

@ -4,28 +4,40 @@ import (
"errors"
)
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
// INTERSECT returns all rows that are in query results.
// It eliminates duplicate rows from its result.
func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...))
}
// INTERSECT_ALL returns all rows that are in query results.
// It does not eliminates duplicate rows from its result.
func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...))
}
func EXCEPT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs, selects...))
// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs.
// It eliminates duplicate rows from its result.
func EXCEPT(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs))
}
func EXCEPT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs, selects...))
// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs.
// It does not eliminates duplicate rows from its result.
func EXCEPT_ALL(lhs, rhs SelectStatement) SelectStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs))
}
func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement {
@ -102,7 +114,7 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
}
if len(s.selects) < 2 {
return errors.New("jet: UNION Statement must have at least two SELECT statements.")
return errors.New("jet: UNION Statement must have at least two SELECT statements")
}
out.newLine()
@ -124,7 +136,7 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
return errors.New("jet: select statement is nil")
}
err := selectStmt.serialize(set_statement, out)
err := selectStmt.serialize(setStatement, out)
if err != nil {
return err
@ -136,7 +148,7 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
out.writeString(")")
if s.orderBy != nil {
err := out.writeOrderBy(set_statement, s.orderBy)
err := out.writeOrderBy(setStatement, s.orderBy)
if err != nil {
return err
}
@ -145,13 +157,13 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
if s.limit >= 0 {
out.newLine()
out.writeString("LIMIT")
out.insertPreparedArgument(s.limit)
out.insertParametrizedArgument(s.limit)
}
if s.offset >= 0 {
out.newLine()
out.writeString("OFFSET")
out.insertPreparedArgument(s.offset)
out.insertParametrizedArgument(s.offset)
}
return nil

View file

@ -6,7 +6,7 @@ import (
)
func TestUnionTwoSelect(t *testing.T) {
var expectedSql = `
var expectedSQL = `
(
(
SELECT table1.col1 AS "table1.col1"
@ -27,8 +27,8 @@ func TestUnionTwoSelect(t *testing.T) {
unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3))
assertStatement(t, unionStmt1, expectedSql)
assertStatement(t, unionStmt2, expectedSql)
assertStatement(t, unionStmt1, expectedSQL)
assertStatement(t, unionStmt2, expectedSQL)
}
func TestUnionNilSelect(t *testing.T) {
@ -49,7 +49,7 @@ func TestUnionThreeSelect1(t *testing.T) {
table3.SELECT(table3Col1),
)
var expectedSql = `
var expectedSQL = `
(
(
@ -71,7 +71,7 @@ func TestUnionThreeSelect1(t *testing.T) {
);
`
assertStatement(t, unionStmt1, expectedSql)
assertStatement(t, unionStmt1, expectedSQL)
}
func TestUnionThreeSelect2(t *testing.T) {
@ -82,7 +82,7 @@ func TestUnionThreeSelect2(t *testing.T) {
table3.SELECT(table3Col1),
)
var expectedSql = `
var expectedSQL = `
(
(
SELECT table1.col1 AS "table1.col1"
@ -101,7 +101,7 @@ func TestUnionThreeSelect2(t *testing.T) {
);
`
assertStatement(t, unionStmt2, expectedSql)
assertStatement(t, unionStmt2, expectedSQL)
}
func TestUnionWithOrderBy(t *testing.T) {
@ -155,7 +155,7 @@ OFFSET $2;
}
func TestUnionInUnion(t *testing.T) {
expectedSql := `
expectedSQL := `
(
(
SELECT table2.col3 AS "table2.col3",
@ -182,7 +182,7 @@ func TestUnionInUnion(t *testing.T) {
UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)),
)
assertStatement(t, query, expectedSql)
assertStatement(t, query, expectedSQL)
}
func TestUnionALL(t *testing.T) {

View file

@ -8,6 +8,7 @@ import (
"strings"
)
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement interface {
// Sql returns parametrized sql query with list of arguments.
// err is returned if statement is not composed correctly
@ -22,12 +23,12 @@ type Statement interface {
Query(db execution.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be of arbitrary structure
QueryContext(db execution.DB, context context.Context, destination interface{}) error
QueryContext(context context.Context, db execution.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows.
Exec(db execution.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows.
ExecContext(db execution.DB, context context.Context) (sql.Result, error)
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
}
func debugSql(statement Statement) (string, error) {
@ -37,14 +38,14 @@ func debugSql(statement Statement) (string, error) {
return "", err
}
debugSqlQuery := sqlQuery
debugSQLQuery := sqlQuery
for i, arg := range args {
argPlaceholder := "$" + strconv.Itoa(i+1)
debugSqlQuery = strings.Replace(debugSqlQuery, argPlaceholder, ArgToString(arg), 1)
debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
}
return debugSqlQuery, nil
return debugSQLQuery, nil
}
func query(statement Statement, db execution.DB, destination interface{}) error {
@ -54,17 +55,17 @@ func query(statement Statement, db execution.DB, destination interface{}) error
return err
}
return execution.Query(db, context.Background(), query, args, destination)
return execution.Query(context.Background(), db, query, args, destination)
}
func queryContext(statement Statement, db execution.DB, context context.Context, destination interface{}) error {
func queryContext(context context.Context, statement Statement, db execution.DB, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, context, query, args, destination)
return execution.Query(context, db, query, args, destination)
}
func exec(statement Statement, db execution.DB) (res sql.Result, err error) {
@ -77,7 +78,7 @@ func exec(statement Statement, db execution.DB) (res sql.Result, err error) {
return db.Exec(query, args...)
}
func execContext(statement Statement, db execution.DB, context context.Context) (res sql.Result, err error) {
func execContext(context context.Context, statement Statement, db execution.DB) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {

View file

@ -1,5 +1,6 @@
package jet
// StringExpression interface
type StringExpression interface {
Expression
@ -108,6 +109,9 @@ func newStringExpressionWrap(expression Expression) StringExpression {
return &stringExpressionWrap
}
// StringExp is string expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as string expression.
// Does not add sql cast to generated sql builder output.
func StringExp(expression Expression) StringExpression {
return newStringExpressionWrap(expression)
}

View file

@ -17,6 +17,16 @@ func TestStringNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_EQ(String("JOHN")), "(table3.col2 != $1)", "JOHN")
}
func TestStringExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table3StrCol.IS_DISTINCT_FROM(table2ColStr), "(table3.col2 IS DISTINCT FROM table2.col_str)")
assertClauseSerialize(t, table3StrCol.IS_DISTINCT_FROM(String("JOHN")), "(table3.col2 IS DISTINCT FROM $1)", "JOHN")
}
func TestStringExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table3StrCol.IS_NOT_DISTINCT_FROM(table2ColStr), "(table3.col2 IS NOT DISTINCT FROM table2.col_str)")
assertClauseSerialize(t, table3StrCol.IS_NOT_DISTINCT_FROM(String("JOHN")), "(table3.col2 IS NOT DISTINCT FROM $1)", "JOHN")
}
func TestStringGT(t *testing.T) {
exp := table3StrCol.GT(table2ColStr)
assertClauseSerialize(t, exp, "(table3.col2 > table2.col_str)")

View file

@ -2,6 +2,7 @@ package jet
import (
"errors"
"github.com/go-jet/jet/internal/utils"
)
type table interface {
@ -36,18 +37,21 @@ type writableTable interface {
LOCK() LockStatement
}
// ReadableTable interface
type ReadableTable interface {
table
readableTable
clause
}
// WritableTable interface
type WritableTable interface {
table
writableTable
clause
}
// Table interface
type Table interface {
table
readableTable
@ -110,6 +114,7 @@ func (w *writableTableInterfaceImpl) LOCK() LockStatement {
return LOCK(w.parent)
}
// NewTable creates new table with schema name, table name and list of columns
func NewTable(schemaName, name string, columns ...Column) Table {
t := &tableImpl{
@ -196,20 +201,20 @@ type joinTable struct {
lhs ReadableTable
rhs ReadableTable
join_type joinType
joinType joinType
onCondition BoolExpression
}
func newJoinTable(
lhs ReadableTable,
rhs ReadableTable,
join_type joinType,
joinType joinType,
onCondition BoolExpression) ReadableTable {
joinTable := &joinTable{
lhs: lhs,
rhs: rhs,
join_type: join_type,
joinType: joinType,
onCondition: onCondition,
}
@ -235,7 +240,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
return errors.New("jet: Join table is nil. ")
}
if isNil(t.lhs) {
if utils.IsNil(t.lhs) {
return errors.New("jet: left hand side of join operation is nil table")
}
@ -245,7 +250,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
out.newLine()
switch t.join_type {
switch t.joinType {
case innerJoin:
out.writeString("INNER JOIN")
case leftJoin:
@ -258,7 +263,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
out.writeString("CROSS JOIN")
}
if isNil(t.rhs) {
if utils.IsNil(t.rhs) {
return errors.New("jet: right hand side of join operation is nil table")
}
@ -266,7 +271,7 @@ func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options
return
}
if t.onCondition == nil && t.join_type != crossJoin {
if t.onCondition == nil && t.joinType != crossJoin {
return errors.New("jet: join condition is nil")
}

View file

@ -127,7 +127,10 @@ func TestStringOperators(t *testing.T) {
LOWER(AllTypes.CharacterVaryingPtr),
UPPER(AllTypes.Character),
BTRIM(AllTypes.CharacterVarying),
BTRIM(AllTypes.CharacterVarying, String("AA")),
LTRIM(AllTypes.CharacterVarying),
LTRIM(AllTypes.CharacterVarying, String("A")),
RTRIM(AllTypes.CharacterVarying),
RTRIM(AllTypes.CharacterVarying, String("B")),
CHR(Int(65)),
//CONCAT(String("string1"), Int(1), Float(11.12)),
@ -143,13 +146,16 @@ func TestStringOperators(t *testing.T) {
RIGHT(String("abcde"), Int(2)),
LENGTH(String("jose")),
LENGTH(String("jose"), String("UTF8")),
LPAD(String("Hi"), Int(5)),
LPAD(String("Hi"), Int(5), String("xy")),
RPAD(String("Hi"), Int(5)),
RPAD(String("Hi"), Int(5), String("xy")),
MD5(AllTypes.CharacterVarying),
REPEAT(AllTypes.Text, Int(33)),
REPLACE(AllTypes.Character, String("BA"), String("AB")),
REVERSE(AllTypes.CharacterVarying),
STRPOS(AllTypes.Text, String("A")),
SUBSTR(AllTypes.CharacterPtr, Int(3)),
SUBSTR(AllTypes.CharacterPtr, Int(3), Int(2)),
TO_HEX(AllTypes.IntegerPtr),
)
@ -229,7 +235,7 @@ func TestFloatOperators(t *testing.T) {
CEIL(AllTypes.Real),
FLOOR(AllTypes.Real),
ROUND(AllTypes.Decimal),
ROUND(AllTypes.Decimal, Int(3)).AS("round"),
ROUND(AllTypes.Decimal, AllTypes.Integer).AS("round"),
SIGN(AllTypes.Real),
TRUNC(AllTypes.Decimal),
TRUNC(AllTypes.Decimal, Int(1)),
@ -361,7 +367,7 @@ func TestSubQueryColumnReference(t *testing.T) {
args []interface{}
}
subQueries := map[ExpressionTable]expected{}
subQueries := map[SelectTable]expected{}
selectSubQuery := AllTypes.SELECT(
AllTypes.Boolean,
@ -378,7 +384,7 @@ func TestSubQueryColumnReference(t *testing.T) {
LIMIT(2).
AsTable("subQuery")
var selectExpectedSql = ` (
var selectexpectedSQL = ` (
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.real AS "all_types.real",
@ -424,7 +430,7 @@ func TestSubQueryColumnReference(t *testing.T) {
).
AsTable("subQuery")
unionExpectedSql := `
unionexpectedSQL := `
(
(
SELECT all_types.boolean AS "all_types.boolean",
@ -458,8 +464,8 @@ func TestSubQueryColumnReference(t *testing.T) {
)
) AS "subQuery"`
subQueries[selectSubQuery] = expected{sql: selectExpectedSql, args: []interface{}{int64(2)}}
subQueries[unionSubQuery] = expected{sql: unionExpectedSql, args: []interface{}{int64(1), int64(1), int64(1)}}
subQueries[selectSubQuery] = expected{sql: selectexpectedSQL, args: []interface{}{int64(2)}}
subQueries[unionSubQuery] = expected{sql: unionexpectedSQL, args: []interface{}{int64(1), int64(1), int64(1)}}
for subQuery, expected := range subQueries {
boolColumn := AllTypes.Boolean.From(subQuery)
@ -487,7 +493,7 @@ func TestSubQueryColumnReference(t *testing.T) {
).
FROM(subQuery)
var expectedSql = `
var expectedSQL = `
SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."all_types.integer" AS "all_types.integer",
"subQuery"."all_types.real" AS "all_types.real",
@ -500,7 +506,7 @@ SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."aliasedColumn" AS "aliasedColumn"
FROM`
assertStatementSql(t, stmt1, expectedSql+expected.sql+";\n", expected.args...)
assertStatementSql(t, stmt1, expectedSQL+expected.sql+";\n", expected.args...)
dest1 := []model.AllTypes{}
err := stmt1.Query(db, &dest1)
@ -523,7 +529,7 @@ FROM`
//fmt.Println(stmt2.DebugSql())
assertStatementSql(t, stmt2, expectedSql+expected.sql+";\n", expected.args...)
assertStatementSql(t, stmt2, expectedSQL+expected.sql+";\n", expected.args...)
dest2 := []model.AllTypes{}
err = stmt2.Query(db, &dest2)

View file

@ -1,6 +1,7 @@
package tests
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -9,6 +10,7 @@ import (
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table"
"gotest.tools/assert"
"io/ioutil"
"runtime"
"testing"
"time"
)
@ -104,7 +106,7 @@ func TestJoinEverything(t *testing.T) {
assert.NilError(t, err)
assert.Equal(t, len(dest), 275)
assertJsonFile(t, "./testdata/joined_everything.json", dest)
assertJSONFile(t, "./testdata/joined_everything.json", dest)
}
func TestSelfJoin(t *testing.T) {
@ -144,7 +146,7 @@ ORDER BY "Employee"."EmployeeId";
assert.NilError(t, err)
assert.Equal(t, len(dest), 8)
assertJson(t, dest[0:2], `
assertJSON(t, dest[0:2], `
[
{
"EmployeeId": 1,
@ -244,23 +246,21 @@ ORDER BY "Album.AlbumId";
assert.DeepEqual(t, dest[1], album2)
}
//func TestQueryWithContext(t *testing.T) {
//
// ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
// defer cancel()
//
// dest := []model.Album{}
//
// err := Album.
// CROSS_JOIN(Track).
// CROSS_JOIN(InvoiceLine).
// SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
// QueryContext(db, ctx, &dest)
//
// spew.Dump(dest)
//
// assert.Error(t, err, "context deadline exceeded")
//}
func TestQueryWithContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
dest := []model.Album{}
err := Album.
CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestExecWithContext(t *testing.T) {
@ -271,7 +271,7 @@ func TestExecWithContext(t *testing.T) {
CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
ExecContext(db, ctx)
ExecContext(ctx, db)
assert.Error(t, err, "pq: canceling statement due to user request")
}
@ -283,7 +283,7 @@ func TestSubQueriesForQuotedNames(t *testing.T) {
LIMIT(10).
AsTable("first10Artist")
artistId := Artist.ArtistId.From(first10Artist)
artistID := Artist.ArtistId.From(first10Artist)
first10Albums := Album.
SELECT(Album.AllColumns).
@ -291,12 +291,12 @@ func TestSubQueriesForQuotedNames(t *testing.T) {
LIMIT(10).
AsTable("first10Albums")
albumArtistId := Album.ArtistId.From(first10Albums)
albumArtistID := Album.ArtistId.From(first10Albums)
stmt := first10Artist.
INNER_JOIN(first10Albums, artistId.EQ(albumArtistId)).
INNER_JOIN(first10Albums, artistID.EQ(albumArtistID)).
SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()).
ORDER_BY(artistId)
ORDER_BY(artistID)
assertStatementSql(t, stmt, `
SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId",
@ -335,22 +335,26 @@ ORDER BY "first10Artist"."Artist.ArtistId";
//spew.Dump(dest)
}
func assertJson(t *testing.T, data interface{}, expectedJson string) {
func assertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJson)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
}
func assertJsonFile(t *testing.T, jsonFilePath string, data interface{}) {
fileJsonData, err := ioutil.ReadFile(jsonFilePath)
func assertJSONFile(t *testing.T, jsonFilePath string, data interface{}) {
fileJSONData, err := ioutil.ReadFile(jsonFilePath)
assert.NilError(t, err)
if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
}
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
assert.Assert(t, string(fileJsonData) == string(jsonData))
//assert.Equal(t, string(fileJsonData), string(jsonData))
assert.Assert(t, string(fileJSONData) == string(jsonData))
//assert.Equal(t, string(fileJSONData), string(jsonData))
}
func jsonPrint(v interface{}) {

View file

@ -2,6 +2,7 @@ package dbconfig
import "fmt"
// test database connection parameters
const (
Host = "localhost"
Port = 5432
@ -10,4 +11,5 @@ const (
DBName = "jetdb"
)
// ConnectString is PostgreSQL connection string
var ConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)

View file

@ -1,17 +1,19 @@
package tests
import (
"context"
. "github.com/go-jet/jet"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert"
"testing"
"time"
)
func TestDeleteWithWhere(t *testing.T) {
initForDeleteTest(t)
var expectedSql = `
var expectedSQL = `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook');
`
@ -19,14 +21,14 @@ WHERE link.name IN ('Gmail', 'Outlook');
DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
assertStatementSql(t, deleteStmt, expectedSql, "Gmail", "Outlook")
assertStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
assertExec(t, deleteStmt, 2)
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
initForDeleteTest(t)
var expectedSql = `
var expectedSQL = `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook')
RETURNING link.id AS "link.id",
@ -39,7 +41,7 @@ RETURNING link.id AS "link.id",
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))).
RETURNING(Link.AllColumns)
assertStatementSql(t, deleteStmt, expectedSql, "Gmail", "Outlook")
assertStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
dest := []model.Link{}
@ -60,3 +62,38 @@ func initForDeleteTest(t *testing.T) {
assertExec(t, stmt, 2)
}
func TestDeleteQueryContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link.
DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := deleteStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestDeleteExecContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link.
DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := deleteStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/tests/dbconfig"
_ "github.com/lib/pq"
"io/ioutil"
)
@ -38,7 +39,7 @@ func main() {
err = postgres.Generate("./.gentestdata", postgres.DBConnection{
Host: dbconfig.Host,
Port: "5432",
Port: 5432,
User: dbconfig.User,
Password: dbconfig.Password,
DBName: dbconfig.DBName,

View file

@ -1,17 +1,19 @@
package tests
import (
"context"
. "github.com/go-jet/jet"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert"
"testing"
"time"
)
func TestInsertValues(t *testing.T) {
cleanUpLinkTable(t)
var expectedSql = `
var expectedSQL = `
INSERT INTO test_sample.link (id, url, name, description) VALUES
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT),
@ -28,7 +30,7 @@ RETURNING link.id AS "link.id",
VALUES(102, "http://www.yahoo.com", "Yahoo", nil).
RETURNING(Link.AllColumns)
assertStatementSql(t, insertQuery, expectedSql,
assertStatementSql(t, insertQuery, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
101, "http://www.google.com", "Google",
102, "http://www.yahoo.com", "Yahoo", nil)
@ -74,7 +76,7 @@ RETURNING link.id AS "link.id",
func TestInsertEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
expectedSql := `
expectedSQL := `
INSERT INTO test_sample.link VALUES
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
`
@ -82,7 +84,7 @@ INSERT INTO test_sample.link VALUES
stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
assertStatementSql(t, stmt, expectedSql,
assertStatementSql(t, stmt, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
assertExec(t, stmt, 1)
@ -90,7 +92,7 @@ INSERT INTO test_sample.link VALUES
func TestInsertModelObject(t *testing.T) {
cleanUpLinkTable(t)
var expectedSql = `
var expectedSQL = `
INSERT INTO test_sample.link (url, name) VALUES
('http://www.duckduckgo.com', 'Duck Duck go');
`
@ -104,19 +106,35 @@ INSERT INTO test_sample.link (url, name) VALUES
INSERT(Link.URL, Link.Name).
MODEL(linkData)
assertStatementSql(t, query, expectedSql, "http://www.duckduckgo.com", "Duck Duck go")
assertStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go")
result, err := query.Exec(db)
assertExec(t, query, 1)
}
assert.NilError(t, err)
func TestInsertModelObjectEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link VALUES
(1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
rowsAffected, err := result.RowsAffected()
linkData := model.Link{
ID: 1000,
URL: "http://www.duckduckgo.com",
Name: "Duck Duck go",
}
assert.Equal(t, rowsAffected, int64(1))
query := Link.
INSERT().
MODEL(linkData)
assertStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
assertExec(t, query, 1)
}
func TestInsertModelsObject(t *testing.T) {
expectedSql := `
expectedSQL := `
INSERT INTO test_sample.link (url, name) VALUES
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
@ -142,7 +160,7 @@ INSERT INTO test_sample.link (url, name) VALUES
INSERT(Link.URL, Link.Name).
MODELS([]model.Link{tutorial, google, yahoo})
assertStatementSql(t, stmt, expectedSql,
assertStatementSql(t, stmt, expectedSQL,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo")
@ -151,7 +169,7 @@ INSERT INTO test_sample.link (url, name) VALUES
}
func TestInsertUsingMutableColumns(t *testing.T) {
var expectedSql = `
var expectedSQL = `
INSERT INTO test_sample.link (url, name, description) VALUES
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
@ -175,7 +193,7 @@ INSERT INTO test_sample.link (url, name, description) VALUES
MODEL(google).
MODELS([]model.Link{google, yahoo})
assertStatementSql(t, stmt, expectedSql,
assertStatementSql(t, stmt, expectedSQL,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", nil,
"http://www.google.com", "Google", nil,
@ -190,7 +208,7 @@ func TestInsertQuery(t *testing.T) {
Exec(db)
assert.NilError(t, err)
var expectedSql = `
var expectedSQL = `
INSERT INTO test_sample.link (url, name) (
SELECT link.url AS "link.url",
link.name AS "link.name"
@ -212,7 +230,7 @@ RETURNING link.id AS "link.id",
).
RETURNING(Link.AllColumns)
assertStatementSql(t, query, expectedSql, int64(0))
assertStatementSql(t, query, expectedSQL, int64(0))
dest := []model.Link{}
@ -229,3 +247,37 @@ RETURNING link.id AS "link.id",
assert.NilError(t, err)
assert.Equal(t, len(youtubeLinks), 2)
}
func TestInsertWithQueryContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT().
VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
RETURNING(Link.AllColumns)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := stmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestInsertWithExecContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := stmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}

73
tests/lock_test.go Normal file
View file

@ -0,0 +1,73 @@
package tests
import (
"context"
"gotest.tools/assert"
"testing"
"time"
. "github.com/go-jet/jet"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table"
)
func TestLockTable(t *testing.T) {
expectedSQL := `
LOCK TABLE dvds.address IN`
var testData = []TableLockMode{
LOCK_ACCESS_SHARE,
LOCK_ROW_SHARE,
LOCK_ROW_EXCLUSIVE,
LOCK_SHARE_UPDATE_EXCLUSIVE,
LOCK_SHARE,
LOCK_SHARE_ROW_EXCLUSIVE,
LOCK_EXCLUSIVE,
LOCK_ACCESS_EXCLUSIVE,
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode)
assertStatementSql(t, query, expectedSQL+" "+string(lockMode)+" MODE;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode).NOWAIT()
assertStatementSql(t, query, expectedSQL+" "+string(lockMode)+" MODE NOWAIT;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
}
func TestLockExecContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
tx, _ := db.Begin()
defer tx.Rollback()
_, err := Address.LOCK().IN(LOCK_ACCESS_SHARE).ExecContext(ctx, tx)
assert.Error(t, err, "context deadline exceeded")
}

View file

@ -1,13 +1,17 @@
package tests
import (
"bytes"
"database/sql"
"github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
"github.com/go-jet/jet/tests/dbconfig"
_ "github.com/lib/pq"
"github.com/pkg/profile"
"gotest.tools/assert"
"io/ioutil"
"os"
"os/exec"
"reflect"
"testing"
)
@ -29,7 +33,7 @@ func TestMain(m *testing.M) {
os.Exit(ret)
}
func TestGenerateModel(t *testing.T) {
func TestGeneratedModel(t *testing.T) {
actor := model.Actor{}
@ -58,3 +62,189 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
}
const genTestDir2 = "./.gentestdata2"
func TestCmdGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir2)
assert.NilError(t, err)
cmd := exec.Command("jet", "-dbname=jetdb", "-host=localhost", "-port=5432",
"-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2)
err = cmd.Run()
assert.NilError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genTestDir2)
assert.NilError(t, err)
}
func TestGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir2)
assert.NilError(t, err)
err = postgres.Generate(genTestDir2, postgres.DBConnection{
Host: dbconfig.Host,
Port: dbconfig.Port,
User: dbconfig.User,
Password: dbconfig.Password,
SslMode: "disable",
Params: "",
DBName: dbconfig.DBName,
SchemaName: "dvds",
})
assert.NilError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genTestDir2)
assert.NilError(t, err)
}
func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table")
assert.NilError(t, err)
assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
// Enums SQL Builder files
enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum")
assert.NilError(t, err)
assertFileNameEqual(t, enumFiles, "mpaa_rating.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files
modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model")
assert.NilError(t, err)
assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go")
assertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile)
}
func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
}
func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
fileNamesMap := map[string]bool{}
for _, fileInfo := range fileInfos {
fileNamesMap[fileInfo.Name()] = true
}
for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.")
}
}
var mpaaRatingEnumFile = `
package enum
import "github.com/go-jet/jet"
var MpaaRating = &struct {
G jet.StringExpression
Pg jet.StringExpression
Pg13 jet.StringExpression
R jet.StringExpression
Nc17 jet.StringExpression
}{
G: jet.NewEnumValue("G"),
Pg: jet.NewEnumValue("PG"),
Pg13: jet.NewEnumValue("PG-13"),
R: jet.NewEnumValue("R"),
Nc17: jet.NewEnumValue("NC-17"),
}
`
var actorSQLBuilderFile = `
package table
import (
"github.com/go-jet/jet"
)
var Actor = newActorTable()
type ActorTable struct {
jet.Table
//Columns
ActorID jet.ColumnInteger
FirstName jet.ColumnString
LastName jet.ColumnString
LastUpdate jet.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
}
// creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) *ActorTable {
aliasTable := newActorTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorTable() *ActorTable {
var (
ActorIDColumn = jet.IntegerColumn("actor_id")
FirstNameColumn = jet.StringColumn("first_name")
LastNameColumn = jet.StringColumn("last_name")
LastUpdateColumn = jet.TimestampColumn("last_update")
)
return &ActorTable{
Table: jet.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn},
}
}
`
var actorModelFile = `
package model
import (
"time"
)
type Actor struct {
ActorID int32 ` + "`sql:\"primary_key\"`" + `
FirstName string
LastName string
LastUpdate time.Time
}
`

View file

@ -61,5 +61,5 @@ func TestNorthwindJoinEverything(t *testing.T) {
assert.NilError(t, err)
//jsonSave("./testdata/northwind-all.json", dest)
assertJsonFile(t, "./testdata/northwind-all.json", dest)
assertJSONFile(t, "./testdata/northwind-all.json", dest)
}

View file

@ -46,7 +46,7 @@ FROM test_sample.person;
err := query.Query(db, &result)
assert.NilError(t, err)
assertJson(t, result, `
assertJSON(t, result, `
[
{
"PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6",
@ -72,7 +72,7 @@ FROM test_sample.person;
func TestSelecSelfJoin1(t *testing.T) {
var expectedSql = `
var expectedSQL = `
SELECT employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name",
@ -97,7 +97,7 @@ ORDER BY employee.employee_id;
).
ORDER_BY(Employee.EmployeeID)
assertStatementSql(t, query, expectedSql)
assertStatementSql(t, query, expectedSQL)
type Manager model.Employee

View file

@ -19,7 +19,7 @@ func TestScanToInvalidDestination(t *testing.T) {
t.Run("nil dest", func(t *testing.T) {
err := query.Query(db, nil)
assert.Error(t, err, "jet: Destination is nil.")
assert.Error(t, err, "jet: Destination is nil")
})
t.Run("struct dest", func(t *testing.T) {

View file

@ -10,7 +10,7 @@ import (
)
func TestSelect_ScanToStruct(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
@ -24,7 +24,7 @@ WHERE actor.actor_id = 1;
DISTINCT().
WHERE(Actor.ActorID.EQ(Int(1)))
assertStatementSql(t, query, expectedSql, int64(1))
assertStatementSql(t, query, expectedSQL, int64(1))
actor := model.Actor{}
err := query.Query(db, &actor)
@ -42,7 +42,7 @@ WHERE actor.actor_id = 1;
}
func TestClassicSelect(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
@ -74,7 +74,7 @@ LIMIT 30;
ORDER_BY(Payment.PaymentID.ASC()).
LIMIT(30)
assertStatementSql(t, query, expectedSql, int64(30))
assertStatementSql(t, query, expectedSQL, int64(30))
dest := []model.Payment{}
@ -85,7 +85,7 @@ LIMIT 30;
}
func TestSelect_ScanToSlice(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
@ -103,7 +103,7 @@ ORDER BY customer.customer_id ASC;
query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC())
assertStatementSql(t, query, expectedSql)
assertStatementSql(t, query, expectedSQL)
err := query.Query(db, &customers)
assert.NilError(t, err)
@ -116,7 +116,7 @@ ORDER BY customer.customer_id ASC;
}
func TestSelectAndUnionInProjection(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id",
(
SELECT customer.customer_id AS "customer.customer_id"
@ -156,12 +156,12 @@ LIMIT 12;
).
LIMIT(12)
assertStatementSql(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
assertStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
}
func TestJoinQueryStruct(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT film_actor.actor_id AS "film_actor.actor_id",
film_actor.film_id AS "film_actor.film_id",
film_actor.last_update AS "film_actor.last_update",
@ -224,7 +224,7 @@ LIMIT 1000;
ORDER_BY(Film.FilmID.ASC()).
LIMIT(1000)
assertStatementSql(t, query, expectedSql, int64(1000))
assertStatementSql(t, query, expectedSQL, int64(1000))
var languageActorFilm []struct {
model.Language
@ -253,7 +253,7 @@ LIMIT 1000;
}
func TestJoinQuerySlice(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT language.language_id AS "language.language_id",
language.name AS "language.name",
language.last_update AS "language.last_update",
@ -290,7 +290,7 @@ LIMIT 15;
WHERE(Film.Rating.EQ(enum.MpaaRating.Nc17)).
LIMIT(15)
assertStatementSql(t, query, expectedSql, int64(15))
assertStatementSql(t, query, expectedSQL, int64(15))
err := query.Query(db, &filmsPerLanguage)
@ -532,7 +532,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id;
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assertJson(t, dest, `
assertJSON(t, dest, `
[
{
"CityID": 312,
@ -657,7 +657,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
}
func TestSelectFullJoin(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
@ -685,7 +685,7 @@ ORDER BY customer.customer_id ASC;
SELECT(Customer.AllColumns, Address.AllColumns).
ORDER_BY(Customer.CustomerID.ASC())
assertStatementSql(t, query, expectedSql)
assertStatementSql(t, query, expectedSQL)
allCustomersAndAddress := []struct {
Address *model.Address
@ -708,7 +708,7 @@ ORDER BY customer.customer_id ASC;
}
func TestSelectFullCrossJoin(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
@ -738,7 +738,7 @@ LIMIT 1000;
ORDER_BY(Customer.CustomerID.ASC()).
LIMIT(1000)
assertStatementSql(t, query, expectedSql, int64(1000))
assertStatementSql(t, query, expectedSQL, int64(1000))
var customerAddresCrosJoined []struct {
model.Customer
@ -753,7 +753,7 @@ LIMIT 1000;
}
func TestSelectSelfJoin(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT f1.film_id AS "f1.film_id",
f1.title AS "f1.title",
f1.description AS "f1.description",
@ -793,7 +793,7 @@ ORDER BY f1.film_id ASC;
SELECT(f1.AllColumns, f2.AllColumns).
ORDER_BY(f1.FilmID.ASC())
assertStatementSql(t, query, expectedSql)
assertStatementSql(t, query, expectedSQL)
type F1 model.Film
type F2 model.Film
@ -813,7 +813,7 @@ ORDER BY f1.film_id ASC;
}
func TestSelectAliasColumn(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT f1.title AS "thesame_length_films.title1",
f2.title AS "thesame_length_films.title2",
f1.length AS "thesame_length_films.length"
@ -835,7 +835,7 @@ LIMIT 1000;
ORDER_BY(f1.Length.ASC(), f1.Title.ASC(), f2.Title.ASC()).
LIMIT(1000)
assertStatementSql(t, query, expectedSql, int64(1000))
assertStatementSql(t, query, expectedSQL, int64(1000))
type thesameLengthFilms struct {
Title1 string
@ -886,11 +886,11 @@ FROM dvds.actor
WHERE(Film.Rating.EQ(enum.MpaaRating.R)).
AsTable("rFilms")
rFilmId := Film.FilmID.From(rRatingFilms)
rFilmID := Film.FilmID.From(rRatingFilms)
query := Actor.
INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.FilmID)).
INNER_JOIN(rRatingFilms, FilmActor.FilmID.EQ(rFilmId)).
INNER_JOIN(rRatingFilms, FilmActor.FilmID.EQ(rFilmID)).
SELECT(
Actor.AllColumns,
FilmActor.AllColumns,
@ -928,7 +928,7 @@ FROM dvds.film;
}
func TestSelectQueryScalar(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT film.film_id AS "film.film_id",
film.title AS "film.title",
film.description AS "film.description",
@ -960,7 +960,7 @@ ORDER BY film.film_id ASC;
WHERE(Film.RentalRate.EQ(maxFilmRentalRate)).
ORDER_BY(Film.FilmID.ASC())
assertStatementSql(t, query, expectedSql)
assertStatementSql(t, query, expectedSQL)
maxRentalRateFilms := []model.Film{}
err := query.Query(db, &maxRentalRateFilms)
@ -989,7 +989,7 @@ ORDER BY film.film_id ASC;
}
func TestSelectGroupByHaving(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT payment.customer_id AS "customer_payment_sum.customer_id",
SUM(payment.amount) AS "customer_payment_sum.amount_sum",
AVG(payment.amount) AS "customer_payment_sum.amount_avg",
@ -1018,7 +1018,7 @@ ORDER BY SUM(payment.amount) ASC;
SUMf(Payment.Amount).GT(Float(100)),
)
assertStatementSql(t, customersPaymentQuery, expectedSql, float64(100))
assertStatementSql(t, customersPaymentQuery, expectedSQL, float64(100))
type CustomerPaymentSum struct {
CustomerID int16
@ -1047,7 +1047,7 @@ ORDER BY SUM(payment.amount) ASC;
}
func TestSelectGroupBy2(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
@ -1077,18 +1077,18 @@ ORDER BY customer_payment_sum."amount_sum" ASC;
GROUP_BY(Payment.CustomerID).
AsTable("customer_payment_sum")
customerId := Payment.CustomerID.From(customersPayments)
customerID := Payment.CustomerID.From(customersPayments)
amountSum := FloatColumn("amount_sum").From(customersPayments)
query := Customer.
INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerId)).
INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerID)).
SELECT(
Customer.AllColumns,
amountSum.AS("CustomerWithAmounts.AmountSum"),
).
ORDER_BY(amountSum.ASC())
assertStatementSql(t, query, expectedSql)
assertStatementSql(t, query, expectedSQL)
type CustomerWithAmounts struct {
Customer *model.Customer
@ -1123,7 +1123,7 @@ func TestSelectStaff(t *testing.T) {
assert.NilError(t, err)
assertJson(t, staffs, `
assertJSON(t, staffs, `
[
{
"StaffID": 1,
@ -1157,7 +1157,7 @@ func TestSelectStaff(t *testing.T) {
func TestSelectTimeColumns(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
@ -1173,7 +1173,7 @@ ORDER BY payment.payment_date ASC;
WHERE(Payment.PaymentDate.LT(Timestamp(2007, 02, 14, 22, 16, 01, 0))).
ORDER_BY(Payment.PaymentDate.ASC())
assertStatementSql(t, query, expectedSql, "2007-02-14 22:16:01.000")
assertStatementSql(t, query, expectedSQL, "2007-02-14 22:16:01.000")
payments := []model.Payment{}
@ -1260,8 +1260,8 @@ func TestAllSetOperators(t *testing.T) {
UNION_ALL,
INTERSECT,
INTERSECT_ALL,
EXCEPT,
EXCEPT_ALL,
//EXCEPT,
//EXCEPT_ALL,
}
expectedDestLen := []int{
@ -1304,63 +1304,15 @@ LIMIT 20;
assertStatementSql(t, query, expectedQuery, int64(1), "ONE", int64(2), "TWO", int64(3), "THREE", "OTHER", int64(20))
dest := []struct {
StaffIdNum string
StaffIDNum string
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 20)
assert.Equal(t, dest[0].StaffIdNum, "TWO")
assert.Equal(t, dest[1].StaffIdNum, "ONE")
}
func TestLockTable(t *testing.T) {
expectedSql := `
LOCK TABLE dvds.address IN`
var testData = []TableLockMode{
LOCK_ACCESS_SHARE,
LOCK_ROW_SHARE,
LOCK_ROW_EXCLUSIVE,
LOCK_SHARE_UPDATE_EXCLUSIVE,
LOCK_SHARE,
LOCK_SHARE_ROW_EXCLUSIVE,
LOCK_EXCLUSIVE,
LOCK_ACCESS_EXCLUSIVE,
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode)
assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode).NOWAIT()
assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE NOWAIT;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
assert.Equal(t, dest[0].StaffIDNum, "TWO")
assert.Equal(t, dest[1].StaffIDNum, "ONE")
}
func getRowLockTestData() map[SelectLock]string {
@ -1373,7 +1325,7 @@ func getRowLockTestData() map[SelectLock]string {
}
func TestRowLock(t *testing.T) {
expectedSql := `
expectedSQL := `
SELECT *
FROM dvds.address
LIMIT 3
@ -1385,7 +1337,7 @@ FOR`
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType)
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+";\n", int64(3))
assertStatementSql(t, query, expectedSQL+" "+lockTypeStr+";\n", int64(3))
tx, _ := db.Begin()
@ -1401,7 +1353,7 @@ FOR`
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.NOWAIT())
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" NOWAIT;\n", int64(3))
assertStatementSql(t, query, expectedSQL+" "+lockTypeStr+" NOWAIT;\n", int64(3))
tx, _ := db.Begin()
@ -1417,7 +1369,7 @@ FOR`
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.SKIP_LOCKED())
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3))
assertStatementSql(t, query, expectedSQL+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3))
tx, _ := db.Begin()
@ -1433,7 +1385,7 @@ FOR`
func TestQuickStart(t *testing.T) {
var expectedSql = `
var expectedSQL = `
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
@ -1488,7 +1440,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
Film.FilmID.ASC(),
)
assertStatementSql(t, stmt, expectedSql, "English", "Action", int64(180))
assertStatementSql(t, stmt, expectedSQL, "English", "Action", int64(180))
var dest []struct {
model.Actor
@ -1506,7 +1458,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest.json", dest)
assertJsonFile(t, "./testdata/quick-start-dest.json", dest)
assertJSONFile(t, "./testdata/quick-start-dest.json", dest)
var dest2 []struct {
model.Category
@ -1519,7 +1471,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest2.json", dest2)
assertJsonFile(t, "./testdata/quick-start-dest2.json", dest2)
assertJSONFile(t, "./testdata/quick-start-dest2.json", dest2)
}
func TestQuickStartWithSubQueries(t *testing.T) {
@ -1529,22 +1481,22 @@ func TestQuickStartWithSubQueries(t *testing.T) {
WHERE(Film.Length.GT(Int(180))).
AsTable("films")
filmId := Film.FilmID.From(filmLogerThan180)
filmLanguageId := Film.LanguageID.From(filmLogerThan180)
filmID := Film.FilmID.From(filmLogerThan180)
filmLanguageID := Film.LanguageID.From(filmLogerThan180)
categoriesNotAction := Category.
SELECT(Category.AllColumns).
WHERE(Category.Name.NOT_EQ(String("Action"))).
AsTable("categories")
categoryId := Category.CategoryID.From(categoriesNotAction)
categoryID := Category.CategoryID.From(categoriesNotAction)
stmt := Actor.
INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)).
INNER_JOIN(filmLogerThan180, filmId.EQ(FilmActor.FilmID)).
INNER_JOIN(Language, Language.LanguageID.EQ(filmLanguageId)).
INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(filmId)).
INNER_JOIN(categoriesNotAction, categoryId.EQ(FilmCategory.CategoryID)).
INNER_JOIN(filmLogerThan180, filmID.EQ(FilmActor.FilmID)).
INNER_JOIN(Language, Language.LanguageID.EQ(filmLanguageID)).
INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(filmID)).
INNER_JOIN(categoriesNotAction, categoryID.EQ(FilmCategory.CategoryID)).
SELECT(
Actor.AllColumns,
filmLogerThan180.AllColumns(),
@ -1552,7 +1504,7 @@ func TestQuickStartWithSubQueries(t *testing.T) {
categoriesNotAction.AllColumns(),
).ORDER_BY(
Actor.ActorID.ASC(),
filmId.ASC(),
filmID.ASC(),
)
var dest []struct {
@ -1571,7 +1523,7 @@ func TestQuickStartWithSubQueries(t *testing.T) {
assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest.json", dest)
assertJsonFile(t, "./testdata/quick-start-dest.json", dest)
assertJSONFile(t, "./testdata/quick-start-dest.json", dest)
var dest2 []struct {
model.Category
@ -1584,5 +1536,5 @@ func TestQuickStartWithSubQueries(t *testing.T) {
assert.NilError(t, err)
//jsonSave("./testdata/quick-start-dest2.json", dest2)
assertJsonFile(t, "./testdata/quick-start-dest2.json", dest2)
assertJSONFile(t, "./testdata/quick-start-dest2.json", dest2)
}

View file

@ -1,11 +1,13 @@
package tests
import (
"context"
. "github.com/go-jet/jet"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert"
"testing"
"time"
)
func TestUpdateValues(t *testing.T) {
@ -16,12 +18,12 @@ func TestUpdateValues(t *testing.T) {
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
var expectedSql = `
var expectedSQL = `
UPDATE test_sample.link
SET (name, url) = ('Bong', 'http://bong.com')
WHERE link.name = 'Bing';
`
assertStatementSql(t, query, expectedSql, "Bong", "http://bong.com", "Bing")
assertStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
assertExec(t, query, 1)
@ -54,7 +56,7 @@ func TestUpdateWithSubQueries(t *testing.T) {
).
WHERE(Link.Name.EQ(String("Bing")))
expectedSql := `
expectedSQL := `
UPDATE test_sample.link
SET (name, url) = ((
SELECT 'Bong'
@ -66,7 +68,7 @@ SET (name, url) = ((
WHERE link.name = 'Bing';
`
assertStatementSql(t, query, expectedSql, "Bong", "Bing", "Bing")
assertStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing")
assertExec(t, query, 1)
}
@ -74,7 +76,7 @@ WHERE link.name = 'Bing';
func TestUpdateAndReturning(t *testing.T) {
setupLinkTableForUpdateTest(t)
expectedSql := `
expectedSQL := `
UPDATE test_sample.link
SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.name = 'Ask'
@ -90,7 +92,7 @@ RETURNING link.id AS "link.id",
WHERE(Link.Name.EQ(String("Ask"))).
RETURNING(Link.AllColumns)
assertStatementSql(t, stmt, expectedSql, "DuckDuckGo", "http://www.duckduckgo.com", "Ask")
assertStatementSql(t, stmt, expectedSQL, "DuckDuckGo", "http://www.duckduckgo.com", "Ask")
links := []model.Link{}
@ -112,7 +114,7 @@ func TestUpdateWithSelect(t *testing.T) {
).
WHERE(Link.ID.EQ(Int(0)))
expectedSql := `
expectedSQL := `
UPDATE test_sample.link
SET (id, url, name, description) = (
SELECT link.id AS "link.id",
@ -124,7 +126,7 @@ SET (id, url, name, description) = (
)
WHERE link.id = 0;
`
assertStatementSql(t, stmt, expectedSql, int64(0), int64(0))
assertStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
assertExec(t, stmt, 1)
}
@ -139,7 +141,7 @@ func TestUpdateWithInvalidSelect(t *testing.T) {
).
WHERE(Link.ID.EQ(Int(0)))
var expectedSql = `
var expectedSQL = `
UPDATE test_sample.link
SET (id, url, name, description) = (
SELECT link.id AS "link.id",
@ -149,7 +151,7 @@ SET (id, url, name, description) = (
)
WHERE link.id = 0;
`
assertStatementSql(t, stmt, expectedSql, int64(0), int64(0))
assertStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
assertExecErr(t, stmt, "pq: number of columns does not match number of values")
}
@ -168,12 +170,12 @@ func TestUpdateWithModelData(t *testing.T) {
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.ID))))
expectedSql := `
expectedSQL := `
UPDATE test_sample.link
SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL)
WHERE link.id = 201;
`
assertStatementSql(t, stmt, expectedSql, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
assertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
assertExec(t, stmt, 1)
}
@ -195,12 +197,12 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.ID))))
var expectedSql = `
var expectedSQL = `
UPDATE test_sample.link
SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.id = 201;
`
assertStatementSql(t, stmt, expectedSql, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
assertStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
assertExec(t, stmt, 1)
}
@ -231,16 +233,53 @@ func TestUpdateWithInvalidModelData(t *testing.T) {
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.Ident))))
var expectedSql = `
var expectedSQL = `
UPDATE test_sample.link
SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL)
WHERE link.id = 201;
`
assertStatementSql(t, stmt, expectedSql, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201))
assertStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201))
assertExecErr(t, stmt, "pq: number of columns does not match number of values")
}
func TestUpdateQueryContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := updateStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestUpdateExecContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := updateStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}
func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t)

View file

@ -1,5 +1,6 @@
package jet
// TimeExpression interface
type TimeExpression interface {
Expression
@ -58,19 +59,15 @@ type prefixTimeExpression struct {
prefixOpExpression
}
func newPrefixTimeExpression(operator string, expression Expression) TimeExpression {
timeExpr := prefixTimeExpression{}
timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
timeExpr.expressionInterfaceImpl.parent = &timeExpr
timeExpr.timeInterfaceImpl.parent = &timeExpr
return &timeExpr
}
func INTERVAL(interval string) Expression {
return newPrefixTimeExpression("INTERVAL", literal(interval))
}
//func newPrefixTimeExpression(operator string, expression Expression) TimeExpression {
// timeExpr := prefixTimeExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
//
// timeExpr.expressionInterfaceImpl.parent = &timeExpr
// timeExpr.timeInterfaceImpl.parent = &timeExpr
//
// return &timeExpr
//}
//---------------------------------------------------//
@ -85,6 +82,9 @@ func newTimeExpressionWrap(expression Expression) TimeExpression {
return &timeExpressionWrap
}
// TimeExp is time expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time expression.
// Does not add sql cast to generated sql builder output.
func TimeExp(expression Expression) TimeExpression {
return newTimeExpressionWrap(expression)
}

View file

@ -4,34 +4,46 @@ import (
"testing"
)
var timeVar = Time(10, 20, 0, 0)
func TestTimeExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)")
assertClauseSerialize(t, table1ColTime.EQ(Time(10, 20, 0, 0)), "(table1.col_time = $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)")
assertClauseSerialize(t, table1ColTime.NOT_EQ(Time(10, 20, 0, 0)), "(table1.col_time != $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT(Time(10, 20, 0, 0)), "(table1.col_time < $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT_EQ(Time(10, 20, 0, 0)), "(table1.col_time <= $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT(Time(10, 20, 0, 0)), "(table1.col_time > $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1::time without time zone)", "10:20:00.000")
}
func TestTimeExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT_EQ(Time(10, 20, 0, 0)), "(table1.col_time >= $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1::time without time zone)", "10:20:00.000")
}
func TestTimeExp(t *testing.T) {

View file

@ -1,5 +1,6 @@
package jet
// TimestampExpression interface
type TimestampExpression interface {
Expression
@ -63,6 +64,9 @@ func newTimestampExpressionWrap(expression Expression) TimestampExpression {
return &timestampExpressionWrap
}
// TimestampExp is timestamp expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp expression.
// Does not add sql cast to generated sql builder output.
func TimestampExp(expression Expression) TimestampExpression {
return newTimestampExpressionWrap(expression)
}

View file

@ -0,0 +1,52 @@
package jet
import "testing"
var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 0)
func TestTimestampExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp),
"(table1.col_timestamp = $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}
func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
}

View file

@ -1,5 +1,6 @@
package jet
// TimestampzExpression interface
type TimestampzExpression interface {
Expression
@ -63,6 +64,9 @@ func newTimestampzExpressionWrap(expression Expression) TimestampzExpression {
return &timestampzExpressionWrap
}
// TimestampzExp is timestamp with time zone expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp with time zone expression.
// Does not add sql cast to generated sql builder output.
func TimestampzExp(expression Expression) TimestampzExpression {
return newTimestampzExpressionWrap(expression)
}

View file

@ -0,0 +1,52 @@
package jet
import "testing"
var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2)
func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}
func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
}

View file

@ -1,16 +1,25 @@
package jet
// TimezExpression interface 'time with time zone'
type TimezExpression interface {
Expression
//EQ
EQ(rhs TimezExpression) BoolExpression
//NOT_EQ
NOT_EQ(rhs TimezExpression) BoolExpression
//IS_DISTINCT_FROM
IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression
//IS_NOT_DISTINCT_FROM
IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression
//LT
LT(rhs TimezExpression) BoolExpression
//LT_EQ
LT_EQ(rhs TimezExpression) BoolExpression
//GT
GT(rhs TimezExpression) BoolExpression
//GT_EQ
GT_EQ(rhs TimezExpression) BoolExpression
}
@ -58,15 +67,15 @@ type prefixTimezExpression struct {
prefixOpExpression
}
func newPrefixTimezExpression(operator string, expression Expression) TimezExpression {
timeExpr := prefixTimezExpression{}
timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
timeExpr.expressionInterfaceImpl.parent = &timeExpr
timeExpr.timezInterfaceImpl.parent = &timeExpr
return &timeExpr
}
//func newPrefixTimezExpression(operator string, expression Expression) TimezExpression {
// timeExpr := prefixTimezExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
//
// timeExpr.expressionInterfaceImpl.parent = &timeExpr
// timeExpr.timezInterfaceImpl.parent = &timeExpr
//
// return &timeExpr
//}
//---------------------------------------------------//
@ -81,6 +90,9 @@ func newTimezExpressionWrap(expression Expression) TimezExpression {
return &timezExpressionWrap
}
// TimezExp is time with time zone expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time with time zone expression.
// Does not add sql cast to generated sql builder output.
func TimezExp(expression Expression) TimezExpression {
return newTimezExpressionWrap(expression)
}

51
timez_expression_test.go Normal file
View file

@ -0,0 +1,51 @@
package jet
import "testing"
var timezVar = Timez(10, 20, 0, 0, 4)
func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04")
}
func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04"))
}

View file

@ -5,8 +5,10 @@ import (
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface {
Statement
@ -61,11 +63,11 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
out.newLine()
out.writeString("UPDATE")
if isNil(u.table) {
if utils.IsNil(u.table) {
return "", nil, errors.New("jet: table to update is nil")
}
if err = u.table.serialize(update_statement, out); err != nil {
if err = u.table.serialize(updateStatement, out); err != nil {
return
}
@ -100,7 +102,7 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
out.writeString("(")
}
err = serializeClauseList(update_statement, u.row, out)
err = serializeClauseList(updateStatement, u.row, out)
if err != nil {
return
@ -114,11 +116,11 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return "", nil, errors.New("jet: WHERE clause not set")
}
if err = out.writeWhere(update_statement, u.where); err != nil {
if err = out.writeWhere(updateStatement, u.where); err != nil {
return
}
if err = out.writeReturning(update_statement, u.returning); err != nil {
if err = out.writeReturning(updateStatement, u.returning); err != nil {
return
}
@ -134,14 +136,14 @@ func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) er
return query(u, db, destination)
}
func (u *updateStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(u, db, context, destination)
func (u *updateStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, u, db, destination)
}
func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(u, db)
}
func (u *updateStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(u, db, context)
func (u *updateStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, u, db)
}

View file

@ -5,7 +5,7 @@ import (
)
func TestUpdateWithOneValue(t *testing.T) {
expectedSql := `
expectedSQL := `
UPDATE db.table1
SET col_int = $1
WHERE table1.col_int >= $2;
@ -14,11 +14,11 @@ WHERE table1.col_int >= $2;
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSql, 1, int64(33))
assertStatement(t, stmt, expectedSQL, 1, int64(33))
}
func TestUpdateWithValues(t *testing.T) {
expectedSql := `
expectedSQL := `
UPDATE db.table1
SET (col_int, col_float) = ($1, $2)
WHERE table1.col_int >= $3;
@ -27,11 +27,11 @@ WHERE table1.col_int >= $3;
SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSql, 1, 22.2, int64(33))
assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33))
}
func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSql := `
expectedSQL := `
UPDATE db.table1
SET col_float = (
SELECT table1.col_float AS "table1.col_float"
@ -48,11 +48,11 @@ RETURNING table1.col1 AS "table1.col1";
WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1)
assertStatement(t, stmt, expectedSql, int64(2))
assertStatement(t, stmt, expectedSQL, int64(2))
}
func TestUpdateColumnsWithSelect(t *testing.T) {
expectedSql := `
expectedSQL := `
UPDATE db.table1
SET (col1, col_float) = (
SELECT table1.col_float AS "table1.col_float",
@ -67,7 +67,7 @@ RETURNING table1.col1 AS "table1.col1";
WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1)
assertStatement(t, stmt, expectedSql, int64(2))
assertStatement(t, stmt, expectedSQL, int64(2))
}
func TestInvalidInputs(t *testing.T) {

View file

@ -7,7 +7,7 @@ import (
"strings"
)
func serializeOrderByClauseList(statement statementType, orderByClauses []OrderByClause, out *sqlBuilder) error {
func serializeOrderByClauseList(statement statementType, orderByClauses []orderByClause, out *sqlBuilder) error {
for i, value := range orderByClauses {
if i > 0 {
@ -32,7 +32,7 @@ func serializeGroupByClauseList(statement statementType, clauses []groupByClause
}
if c == nil {
return errors.New("jet: nil clause.")
return errors.New("jet: nil clause")
}
if err = c.serializeForGroupBy(statement, out); err != nil {
@ -51,7 +51,7 @@ func serializeClauseList(statement statementType, clauses []clause, out *sqlBuil
}
if c == nil {
return errors.New("jet: nil clause.")
return errors.New("jet: nil clause")
}
if err = c.serialize(statement, out); err != nil {
@ -124,16 +124,12 @@ func columnListToProjectionList(columns []Column) []projection {
return ret
}
func isNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
}
func valueToClause(value interface{}) clause {
if clause, ok := value.(clause); ok {
return clause
} else {
return literal(value)
}
return literal(value)
}
func unwindRowFromModel(columns []column, data interface{}) []clause {

View file

@ -10,7 +10,11 @@ var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1 = NewTable(
"db",
@ -20,7 +24,12 @@ var table1 = NewTable(
table1ColFloat,
table1Col3,
table1ColTime,
table1ColBool)
table1ColTimez,
table1ColBool,
table1ColDate,
table1ColTimestamp,
table1ColTimestampz,
)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
@ -29,6 +38,10 @@ var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool")
var table2ColTime = TimeColumn("col_time")
var table2ColTimez = TimezColumn("col_timez")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable(
"db",
@ -39,7 +52,12 @@ var table2 = NewTable(
table2ColFloat,
table2ColStr,
table2ColBool,
table2ColTime)
table2ColTime,
table2ColTimez,
table2ColDate,
table2ColTimestamp,
table2ColTimestampz,
)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
@ -53,7 +71,7 @@ var table3 = NewTable(
func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) {
out := sqlBuilder{}
err := clause.serialize(select_statement, &out)
err := clause.serialize(selectStatement, &out)
assert.NilError(t, err)
@ -63,7 +81,7 @@ func assertClauseSerialize(t *testing.T, clause clause, query string, args ...in
func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
out := sqlBuilder{}
err := clause.serialize(select_statement, &out)
err := clause.serialize(selectStatement, &out)
//fmt.Println(out.buff.String())
assert.Assert(t, err != nil)
@ -72,7 +90,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) {
out := sqlBuilder{}
err := projection.serializeForProjection(select_statement, &out)
err := projection.serializeForProjection(selectStatement, &out)
assert.NilError(t, err)