Merge pull request #9 from go-jet/mysql

MySQL and MariaDB support
This commit is contained in:
go-jet 2019-08-16 13:10:38 +02:00 committed by GitHub
commit de57a52acc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
205 changed files with 11704 additions and 467715 deletions

View file

@ -3,17 +3,25 @@
# Check https://circleci.com/docs/2.0/language-go/ for more details # Check https://circleci.com/docs/2.0/language-go/ for more details
version: 2 version: 2
jobs: jobs:
build: build-postgres-and-mysql:
docker: docker:
# specify the version # specify the version
- image: circleci/golang:1.11 - image: circleci/golang:1.11
- image: circleci/postgres:10.6-alpine - image: circleci/postgres:10.8-alpine
environment: # environment variables for primary container environment: # environment variables for primary container
POSTGRES_USER: jet POSTGRES_USER: jet
POSTGRES_PASSWORD: jet POSTGRES_PASSWORD: jet
POSTGRES_DB: jetdb POSTGRES_DB: jetdb
- image: circleci/mysql:8.0
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself environment: # environment variables for the build itself
@ -22,12 +30,20 @@ jobs:
steps: steps:
- checkout - checkout
# specify any bash command here prefixed with `run: ` - run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
go get github.com/google/uuid go get github.com/google/uuid
go get github.com/lib/pq go get github.com/lib/pq
go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile go get github.com/pkg/profile
go get gotest.tools/assert go get gotest.tools/assert
@ -48,14 +64,37 @@ jobs:
echo Failed waiting for Postgres && exit 1 echo Failed waiting for Postgres && exit 1
- run: - run:
name: Init Postgres database name: Waiting for MySQL to be ready
command: | command: |
cd tests for i in `seq 1 10`;
go run ./init/init.go do
cd .. nc -z 127.0.0.1 3306 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run:
name: Install MySQL CLI;
command: |
sudo apt-get install default-mysql-client
- run:
name: Create MySQL user and databases
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
- run:
name: Init Postgres database
command: |
cd tests
go run ./init/init.go -testsuite all
cd ..
- run: mkdir -p $TEST_RESULTS - run: mkdir -p $TEST_RESULTS
- run: go test -v . ./tests -coverpkg=github.com/go-jet/jet,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run: - run:
name: Upload code coverage name: Upload code coverage
@ -67,4 +106,75 @@ jobs:
- store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/ - store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/
path: /tmp/test-results path: /tmp/test-results
build-mariadb:
docker:
# specify the version
- image: circleci/golang:1.11
- image: circleci/mariadb:10.3
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself
TEST_RESULTS: /tmp/test-results # path to where test results will be saved
steps:
- checkout
- run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run:
name: Install dependencies
command: |
go get github.com/google/uuid
go get github.com/lib/pq
go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile
go get gotest.tools/assert
go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report
go install github.com/go-jet/jet/cmd/jet
- run:
name: Install MySQL CLI;
command: |
sudo apt-get install default-mysql-client
- run:
name: Init MariaDB database
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
- run:
name: Init MariaDB database
command: |
cd tests
go run ./init/init.go -testsuite MariaDB
cd ..
- run:
name: Run MariaDB tests
command: |
go test -v ./tests/mysql/ -source=MariaDB
workflows:
version: 2
build_and_test:
jobs:
- build-postgres-and-mysql
- build-mariadb

1
.gitignore vendored
View file

@ -18,3 +18,4 @@
# Test files # Test files
gen gen
.gentestdata .gentestdata
.tests/testdata/

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "tests/testdata"]
path = tests/testdata
url = https://github.com/go-jet/jet-test-data

View file

@ -1,9 +1,11 @@
# Jet # Jet
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet)
[![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet) [![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet)
[![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet) [![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet) [![GitHub release](https://img.shields.io/github/release/go-jet/jet.svg)](https://github.com/go-jet/jet/releases)
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily
convert database query result to desired arbitrary structure. convert database query result to desired arbitrary structure.
@ -258,7 +260,7 @@ above statement. Usually this is the most complex and tedious work, but with Jet
First we have to create desired structure to store query result set. First we have to create desired structure to store query result set.
This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information). This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information).
Let's say this is our desired structure, created by combining auto-generated model types: Let's say this is our desired structure:
```go ```go
var dest []struct { var dest []struct {
model.Actor model.Actor
@ -506,7 +508,7 @@ return result in one database call. Handler execution will be only proportional
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate). ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call. With Jet you can even join the whole database and store the whole structured result in in one query call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/chinook_db_test.go#L40). This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40).
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found ##### How quickly bugs are found

1
_config.yml Normal file
View file

@ -0,0 +1 @@
theme: jekyll-theme-tactile

View file

@ -1,34 +0,0 @@
package jet
type alias struct {
expression Expression
alias string
}
func newAlias(expression Expression, aliasName string) projection {
return &alias{
expression: expression,
alias: aliasName,
}
}
func (a *alias) from(subQuery SelectTable) projection {
column := newColumn(a.alias, "", nil)
column.parent = &column
column.subQuery = subQuery
return &column
}
func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder) error {
err := a.expression.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeQuotedString(a.alias)
return nil
}

255
clause.go
View file

@ -1,255 +0,0 @@
package jet
import (
"bytes"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
)
type serializeOption int
const (
noWrap serializeOption = iota
)
type clause interface {
serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error
}
func contains(options []serializeOption, option serializeOption) bool {
for _, opt := range options {
if opt == option {
return true
}
}
return false
}
type sqlBuilder struct {
buff bytes.Buffer
args []interface{}
lastChar byte
ident int
}
type statementType string
const (
selectStatement statementType = "SELECT"
insertStatement statementType = "INSERT"
updateStatement statementType = "UPDATE"
deleteStatement statementType = "DELETE"
setStatement statementType = "SET"
lockStatement statementType = "LOCK"
)
const defaultIdent = 5
func (q *sqlBuilder) increaseIdent() {
q.ident += defaultIdent
}
func (q *sqlBuilder) decreaseIdent() {
if q.ident < defaultIdent {
q.ident = 0
}
q.ident -= defaultIdent
}
func (q *sqlBuilder) writeProjections(statement statementType, projections []projection) error {
q.increaseIdent()
err := serializeProjectionList(statement, projections, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeFrom(statement statementType, table ReadableTable) error {
q.newLine()
q.writeString("FROM")
q.increaseIdent()
err := table.serialize(statement, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeWhere(statement statementType, where Expression) error {
q.newLine()
q.writeString("WHERE")
q.increaseIdent()
err := where.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.newLine()
q.writeString("GROUP BY")
q.increaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.newLine()
q.writeString("ORDER BY")
q.increaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeHaving(statement statementType, having Expression) error {
q.newLine()
q.writeString("HAVING")
q.increaseIdent()
err := having.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeReturning(statement statementType, returning []projection) error {
if len(returning) == 0 {
return nil
}
q.newLine()
q.writeString("RETURNING")
q.increaseIdent()
return q.writeProjections(statement, returning)
}
func (q *sqlBuilder) newLine() {
q.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident))
}
func (q *sqlBuilder) write(data []byte) {
if len(data) == 0 {
return
}
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.buff.Len() > 0 {
q.buff.WriteByte(' ')
}
q.buff.Write(data)
q.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
func (q *sqlBuilder) writeQuotedString(str string) {
q.writeString(`"` + str + `"`)
}
func (q *sqlBuilder) writeString(str string) {
q.write([]byte(str))
}
func (q *sqlBuilder) writeIdentifier(name string) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap {
q.writeString(`"` + name + `"`)
} else {
q.writeString(name)
}
}
func (q *sqlBuilder) writeByte(b byte) {
q.write([]byte{b})
}
func (q *sqlBuilder) finalize() (string, []interface{}) {
return q.buff.String() + ";\n", q.args
}
func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
q.writeString(argToString(arg))
}
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.writeString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

View file

@ -3,12 +3,19 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/generator/postgres" mysqlgen "github.com/go-jet/jet/generator/mysql"
postgresgen "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/postgres"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"os" "os"
"strings"
) )
var ( var (
source string
host string host string
port int port int
user string user string
@ -22,14 +29,16 @@ var (
) )
func init() { func init() {
flag.StringVar(&source, "source", "", "Database system name (PostgreSQL or MySQL)")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port") flag.IntVar(&port, "port", 0, "Database port")
flag.StringVar(&user, "user", "", "Database user") flag.StringVar(&user, "user", "", "Database user")
flag.StringVar(&password, "password", "", "The users password") flag.StringVar(&password, "password", "", "The users password")
flag.StringVar(&sslmode, "sslmode", "disable", "Whether or not to use SSL(optional)")
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional)") flag.StringVar(&params, "params", "", "Additional connection string parameters(optional)")
flag.StringVar(&dbName, "dbname", "", "name of the database") flag.StringVar(&dbName, "dbname", "", "Database name")
flag.StringVar(&schemaName, "schema", "public", "Database schema name.") flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public") (ignored for MySQL)`)
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL(optional)(default "disable") (ignored for MySQL)`)
flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") flag.StringVar(&destDir, "path", "", "Destination dir for files generated.")
} }
@ -38,7 +47,11 @@ func main() {
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, ` _, _ = fmt.Fprint(os.Stdout, `
Usage of jet: Jet generator 2.0.0
Usage:
-source string
Database system name (PostgreSQL or MySQL)
-host string -host string
Database host path (Example: localhost) Database host path (Example: localhost)
-port int -port int
@ -48,13 +61,13 @@ Usage of jet:
-password string -password string
The users password The users password
-dbname string -dbname string
name of the database Database name
-params string -params string
Additional connection string parameters(optional) Additional connection string parameters(optional)
-schema string -schema string
Database schema name. (default "public") Database schema name. (default "public") (ignored for MySQL)
-sslmode string -sslmode string
Whether or not to use SSL(optional) (default "disable") Whether or not to use SSL(optional) (default "disable") (ignored for MySQL)
-path string -path string
Destination dir for files generated. Destination dir for files generated.
`) `)
@ -62,28 +75,54 @@ Usage of jet:
flag.Parse() flag.Parse()
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" { if source == "" || host == "" || port == 0 || user == "" || dbName == "" {
fmt.Println("\njet: required flag missing") printErrorAndExit("\nERROR: required flag(s) missing")
flag.Usage()
os.Exit(-2)
} }
genData := postgres.DBConnection{ var err error
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName, switch strings.ToLower(strings.TrimSpace(source)) {
SchemaName: schemaName, case strings.ToLower(postgres.Dialect.Name()):
genData := postgresgen.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName,
SchemaName: schemaName,
}
err = postgresgen.Generate(destDir, genData)
case strings.ToLower(mysql.Dialect.Name()):
dbConn := mysqlgen.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName,
}
err = mysqlgen.Generate(destDir, dbConn)
default:
fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.")
os.Exit(-4)
} }
err := postgres.Generate(destDir, genData)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(-1) os.Exit(-5)
} }
} }
func printErrorAndExit(error string) {
fmt.Println(error)
flag.Usage()
os.Exit(-2)
}

145
column.go
View file

@ -1,145 +0,0 @@
// Modeling of columns
package jet
type column interface {
Name() string
TableName() string
setTableName(table string)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// Column is common column interface for all types of columns.
type Column interface {
Expression
column
}
// The base type for real materialized columns.
type columnImpl struct {
expressionInterfaceImpl
name string
tableName string
subQuery SelectTable
}
func newColumn(name string, tableName string, parent Column) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
}
bc.expressionInterfaceImpl.parent = parent
return bc
}
func (c *columnImpl) Name() string {
return c.name
}
func (c *columnImpl) TableName() string {
return c.tableName
}
func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string {
if c.tableName != "" {
return c.tableName + "." + c.name
}
return c.name
}
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if statement == setStatement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.writeString(`"` + c.defaultAlias() + `"`) //always quote
return nil
}
return c.serialize(statement, out)
}
func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuilder) error {
err := c.serialize(statement, out)
if err != nil {
return err
}
out.writeString(`AS "` + c.defaultAlias() + `"`)
return nil
}
func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c.subQuery != nil {
out.writeIdentifier(c.subQuery.Alias())
out.writeByte('.')
out.writeQuotedString(c.defaultAlias())
} else {
if c.tableName != "" {
out.writeIdentifier(c.tableName)
out.writeByte('.')
}
out.writeIdentifier(c.name)
}
return nil
}
//------------------------------------------------------//
// ColumnList is redefined type to support list of columns as single projection
type ColumnList []Column
// projection interface implementation
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.from(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serializeForProjection(statement statementType, out *sqlBuilder) error {
projections := columnListToProjectionList(cl)
err := serializeProjectionList(statement, projections, out)
if err != nil {
return err
}
return nil
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }

View file

@ -1,45 +0,0 @@
package jet
import "testing"
var dateVar = Date(2000, 12, 30)
func TestDateExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.EQ(table2ColDate), "(table1.col_date = table2.col_date)")
assertClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::date)", "2000-12-30")
}
func TestDateExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.NOT_EQ(table2ColDate), "(table1.col_date != table2.col_date)")
assertClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::date)", "2000-12-30")
}
func TestDateExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(table2ColDate), "(table1.col_date IS DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(table2ColDate), "(table1.col_date IS NOT DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT(table2ColDate), "(table1.col_date > table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::date)", "2000-12-30")
}
func TestDateExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT_EQ(table2ColDate), "(table1.col_date >= table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::date)", "2000-12-30")
}
func TestDateExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT(table2ColDate), "(table1.col_date < table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::date)", "2000-12-30")
}
func TestDateExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT_EQ(table2ColDate), "(table1.col_date <= table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::date)", "2000-12-30")
}

View file

@ -1,102 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
RETURNING(projections ...projection) DeleteStatement
}
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
returning []projection
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStatement {
d.returning = projections
return d
}
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
if d == nil {
return errors.New("jet: delete statement is nil")
}
out.newLine()
out.writeString("DELETE FROM")
if d.table == nil {
return errors.New("jet: nil tableName")
}
if err := d.table.serialize(deleteStatement, out); err != nil {
return err
}
if d.where == nil {
return errors.New("jet: deleting without a WHERE clause")
}
if err := out.writeWhere(deleteStatement, d.where); err != nil {
return err
}
if err := out.writeReturning(deleteStatement, d.returning); err != nil {
return err
}
return nil
}
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &sqlBuilder{}
err = d.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}
func (d *deleteStatementImpl) DebugSql() (query string, err error) {
return debugSql(d)
}
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, d, db, destination)
}
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, d, db)
}

View file

@ -1,25 +0,0 @@
package jet
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementErr(t, table1.DELETE(), `jet: deleting without a WHERE clause`)
assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: deleting without a WHERE clause`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = $1;
`, int64(1))
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), `
DELETE FROM db.table1
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`, int64(1))
}

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -8,18 +8,18 @@
package enum package enum
import "github.com/go-jet/jet" import "github.com/go-jet/jet/postgres"
var MpaaRating = &struct { var MpaaRating = &struct {
G jet.StringExpression G postgres.StringExpression
Pg jet.StringExpression Pg postgres.StringExpression
Pg13 jet.StringExpression Pg13 postgres.StringExpression
R jet.StringExpression R postgres.StringExpression
Nc17 jet.StringExpression Nc17 postgres.StringExpression
}{ }{
G: jet.NewEnumValue("G"), G: postgres.NewEnumValue("G"),
Pg: jet.NewEnumValue("PG"), Pg: postgres.NewEnumValue("PG"),
Pg13: jet.NewEnumValue("PG-13"), Pg13: postgres.NewEnumValue("PG-13"),
R: jet.NewEnumValue("R"), R: postgres.NewEnumValue("R"),
Nc17: jet.NewEnumValue("NC-17"), Nc17: postgres.NewEnumValue("NC-17"),
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,22 +9,22 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Actor = newActorTable() var Actor = newActorTable()
type ActorTable struct { type ActorTable struct {
jet.Table postgres.Table
//Columns //Columns
ActorID jet.ColumnInteger ActorID postgres.ColumnInteger
FirstName jet.ColumnString FirstName postgres.ColumnString
LastName jet.ColumnString LastName postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new ActorTable with assigned alias // creates new ActorTable with assigned alias
@ -38,14 +38,14 @@ func (a *ActorTable) AS(alias string) *ActorTable {
func newActorTable() *ActorTable { func newActorTable() *ActorTable {
var ( var (
ActorIDColumn = jet.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = jet.StringColumn("first_name") FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = jet.StringColumn("last_name") LastNameColumn = postgres.StringColumn("last_name")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &ActorTable{ return &ActorTable{
Table: jet.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
@ -53,7 +53,7 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn, LastName: LastNameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, MutableColumns: postgres.ColumnList(FirstNameColumn, LastNameColumn, LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Category = newCategoryTable() var Category = newCategoryTable()
type CategoryTable struct { type CategoryTable struct {
jet.Table postgres.Table
//Columns //Columns
CategoryID jet.ColumnInteger CategoryID postgres.ColumnInteger
Name jet.ColumnString Name postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new CategoryTable with assigned alias // creates new CategoryTable with assigned alias
@ -37,20 +37,20 @@ func (a *CategoryTable) AS(alias string) *CategoryTable {
func newCategoryTable() *CategoryTable { func newCategoryTable() *CategoryTable {
var ( var (
CategoryIDColumn = jet.IntegerColumn("category_id") CategoryIDColumn = postgres.IntegerColumn("category_id")
NameColumn = jet.StringColumn("name") NameColumn = postgres.StringColumn("name")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &CategoryTable{ return &CategoryTable{
Table: jet.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn),
//Columns //Columns
CategoryID: CategoryIDColumn, CategoryID: CategoryIDColumn,
Name: NameColumn, Name: NameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(CategoryIDColumn, NameColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{NameColumn, LastUpdateColumn}, MutableColumns: postgres.ColumnList(NameColumn, LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,31 +9,31 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Film = newFilmTable() var Film = newFilmTable()
type FilmTable struct { type FilmTable struct {
jet.Table postgres.Table
//Columns //Columns
FilmID jet.ColumnInteger FilmID postgres.ColumnInteger
Title jet.ColumnString Title postgres.ColumnString
Description jet.ColumnString Description postgres.ColumnString
ReleaseYear jet.ColumnInteger ReleaseYear postgres.ColumnInteger
LanguageID jet.ColumnInteger LanguageID postgres.ColumnInteger
RentalDuration jet.ColumnInteger RentalDuration postgres.ColumnInteger
RentalRate jet.ColumnFloat RentalRate postgres.ColumnFloat
Length jet.ColumnInteger Length postgres.ColumnInteger
ReplacementCost jet.ColumnFloat ReplacementCost postgres.ColumnFloat
Rating jet.ColumnString Rating postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
SpecialFeatures jet.ColumnString SpecialFeatures postgres.ColumnString
Fulltext jet.ColumnString Fulltext postgres.ColumnString
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new FilmTable with assigned alias // creates new FilmTable with assigned alias
@ -47,23 +47,23 @@ func (a *FilmTable) AS(alias string) *FilmTable {
func newFilmTable() *FilmTable { func newFilmTable() *FilmTable {
var ( var (
FilmIDColumn = jet.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
TitleColumn = jet.StringColumn("title") TitleColumn = postgres.StringColumn("title")
DescriptionColumn = jet.StringColumn("description") DescriptionColumn = postgres.StringColumn("description")
ReleaseYearColumn = jet.IntegerColumn("release_year") ReleaseYearColumn = postgres.IntegerColumn("release_year")
LanguageIDColumn = jet.IntegerColumn("language_id") LanguageIDColumn = postgres.IntegerColumn("language_id")
RentalDurationColumn = jet.IntegerColumn("rental_duration") RentalDurationColumn = postgres.IntegerColumn("rental_duration")
RentalRateColumn = jet.FloatColumn("rental_rate") RentalRateColumn = postgres.FloatColumn("rental_rate")
LengthColumn = jet.IntegerColumn("length") LengthColumn = postgres.IntegerColumn("length")
ReplacementCostColumn = jet.FloatColumn("replacement_cost") ReplacementCostColumn = postgres.FloatColumn("replacement_cost")
RatingColumn = jet.StringColumn("rating") RatingColumn = postgres.StringColumn("rating")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
SpecialFeaturesColumn = jet.StringColumn("special_features") SpecialFeaturesColumn = postgres.StringColumn("special_features")
FulltextColumn = jet.StringColumn("fulltext") FulltextColumn = postgres.StringColumn("fulltext")
) )
return &FilmTable{ return &FilmTable{
Table: jet.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn), Table: postgres.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
//Columns //Columns
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
@ -80,7 +80,7 @@ func newFilmTable() *FilmTable {
SpecialFeatures: SpecialFeaturesColumn, SpecialFeatures: SpecialFeaturesColumn,
Fulltext: FulltextColumn, Fulltext: FulltextColumn,
AllColumns: jet.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, AllColumns: postgres.ColumnList(FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
MutableColumns: jet.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, MutableColumns: postgres.ColumnList(TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var FilmActor = newFilmActorTable() var FilmActor = newFilmActorTable()
type FilmActorTable struct { type FilmActorTable struct {
jet.Table postgres.Table
//Columns //Columns
ActorID jet.ColumnInteger ActorID postgres.ColumnInteger
FilmID jet.ColumnInteger FilmID postgres.ColumnInteger
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new FilmActorTable with assigned alias // creates new FilmActorTable with assigned alias
@ -37,20 +37,20 @@ func (a *FilmActorTable) AS(alias string) *FilmActorTable {
func newFilmActorTable() *FilmActorTable { func newFilmActorTable() *FilmActorTable {
var ( var (
ActorIDColumn = jet.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FilmIDColumn = jet.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &FilmActorTable{ return &FilmActorTable{
Table: jet.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(ActorIDColumn, FilmIDColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{LastUpdateColumn}, MutableColumns: postgres.ColumnList(LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var FilmCategory = newFilmCategoryTable() var FilmCategory = newFilmCategoryTable()
type FilmCategoryTable struct { type FilmCategoryTable struct {
jet.Table postgres.Table
//Columns //Columns
FilmID jet.ColumnInteger FilmID postgres.ColumnInteger
CategoryID jet.ColumnInteger CategoryID postgres.ColumnInteger
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new FilmCategoryTable with assigned alias // creates new FilmCategoryTable with assigned alias
@ -37,20 +37,20 @@ func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable {
func newFilmCategoryTable() *FilmCategoryTable { func newFilmCategoryTable() *FilmCategoryTable {
var ( var (
FilmIDColumn = jet.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
CategoryIDColumn = jet.IntegerColumn("category_id") CategoryIDColumn = postgres.IntegerColumn("category_id")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &FilmCategoryTable{ return &FilmCategoryTable{
Table: jet.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
//Columns //Columns
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
CategoryID: CategoryIDColumn, CategoryID: CategoryIDColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{LastUpdateColumn}, MutableColumns: postgres.ColumnList(LastUpdateColumn),
} }
} }

View file

@ -1,6 +1,6 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST // Generated at Thursday, 08-Aug-19 16:59:58 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/postgres"
) )
var Language = newLanguageTable() var Language = newLanguageTable()
type LanguageTable struct { type LanguageTable struct {
jet.Table postgres.Table
//Columns //Columns
LanguageID jet.ColumnInteger LanguageID postgres.ColumnInteger
Name jet.ColumnString Name postgres.ColumnString
LastUpdate jet.ColumnTimestamp LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList AllColumns postgres.IColumnList
MutableColumns jet.ColumnList MutableColumns postgres.IColumnList
} }
// creates new LanguageTable with assigned alias // creates new LanguageTable with assigned alias
@ -37,20 +37,20 @@ func (a *LanguageTable) AS(alias string) *LanguageTable {
func newLanguageTable() *LanguageTable { func newLanguageTable() *LanguageTable {
var ( var (
LanguageIDColumn = jet.IntegerColumn("language_id") LanguageIDColumn = postgres.IntegerColumn("language_id")
NameColumn = jet.StringColumn("name") NameColumn = postgres.StringColumn("name")
LastUpdateColumn = jet.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
) )
return &LanguageTable{ return &LanguageTable{
Table: jet.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn), Table: postgres.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn),
//Columns //Columns
LanguageID: LanguageIDColumn, LanguageID: LanguageIDColumn,
Name: NameColumn, Name: NameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn}, AllColumns: postgres.ColumnList(LanguageIDColumn, NameColumn, LastUpdateColumn),
MutableColumns: jet.ColumnList{NameColumn, LastUpdateColumn}, MutableColumns: postgres.ColumnList(NameColumn, LastUpdateColumn),
} }
} }

View file

@ -9,8 +9,8 @@ import (
// dot import so go code would resemble as much as native SQL // dot import so go code would resemble as much as native SQL
// dot import is not mandatory // dot import is not mandatory
. "github.com/go-jet/jet"
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
. "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model" "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model"
) )
@ -24,7 +24,6 @@ const (
) )
func main() { func main() {
// Connect to database // Connect to database
var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)
@ -97,17 +96,15 @@ func jsonSave(path string, v interface{}) {
} }
} }
func printStatementInfo(stmt Statement) { func printStatementInfo(stmt SelectStatement) {
query, args, err := stmt.Sql() query, args := stmt.Sql()
panicOnError(err)
fmt.Println("Parameterized query: ") fmt.Println("Parameterized query: ")
fmt.Println(query) fmt.Println(query)
fmt.Println("Arguments: ") fmt.Println("Arguments: ")
fmt.Println(args) fmt.Println(args)
debugSQL, err := stmt.DebugSql() debugSQL := stmt.DebugSql()
panicOnError(err)
fmt.Println("\n\n==============================") fmt.Println("\n\n==============================")

View file

@ -4,10 +4,10 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"github.com/go-jet/jet/execution/internal" "github.com/go-jet/jet/execution/internal"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -18,14 +18,11 @@ import (
// Destination can be either pointer to struct or pointer to slice of structs. // Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error { func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
if utils.IsNil(destinationPtr) { utils.MustBeInitializedPtr(db, "jet: db is nil")
return errors.New("jet: Destination is nil") utils.MustBeInitializedPtr(destinationPtr, "jet: destination is nil")
} utils.MustBe(destinationPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct")
destinationPtrType := reflect.TypeOf(destinationPtr) destinationPtrType := reflect.TypeOf(destinationPtr)
if destinationPtrType.Kind() != reflect.Ptr {
return errors.New("jet: Destination has to be a pointer to slice or pointer to struct")
}
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(context, db, query, args, destinationPtr) return queryToSlice(context, db, query, args, destinationPtr)
@ -51,24 +48,11 @@ func Query(context context.Context, db DB, query string, args []interface{}, des
} }
return nil return nil
} else { } else {
return errors.New("jet: unsupported destination type") panic("jet: destination has to be a pointer to slice or pointer to struct")
} }
} }
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error { func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error {
if db == nil {
return errors.New("jet: db is nil")
}
if slicePtr == nil {
return errors.New("jet: Destination is nil. ")
}
destinationType := reflect.TypeOf(slicePtr)
if destinationType.Kind() != reflect.Ptr && destinationType.Elem().Kind() != reflect.Slice {
return errors.New("jet: Destination has to be a pointer to slice. ")
}
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
@ -126,14 +110,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
sliceElemType := getSliceElemType(slicePtrValue) sliceElemType := getSliceElemType(slicePtrValue)
if isGoBaseType(sliceElemType) { if isSimpleModelType(sliceElemType) {
updated, err = mapRowToBaseTypeSlice(scanContext, slicePtrValue, field) updated, err = mapRowToBaseTypeSlice(scanContext, slicePtrValue, field)
return return
} }
if sliceElemType.Kind() != reflect.Struct { utils.TypeMustBe(sliceElemType, reflect.Struct, "jet: unsupported slice element type"+fieldToString(field))
return false, errors.New("jet: Unsupported dest type: " + field.Name + " " + field.Type.String())
}
structGroupKey := scanContext.getGroupKey(sliceElemType, field) structGroupKey := scanContext.getGroupKey(sliceElemType, field)
@ -226,7 +208,7 @@ func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
if implementsScannerType(field.Type) { if implementsScannerType(field.Type) {
fieldMap.implementsScanner = true fieldMap.implementsScanner = true
} else if !isGoBaseType(field.Type) { } else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true fieldMap.complexType = true
} }
@ -249,6 +231,10 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
field := structType.Field(i) field := structType.Field(i)
fieldValue := structValue.Field(i) fieldValue := structValue.Field(i)
if !fieldValue.CanSet() { // private field
continue
}
fieldMap := typeInf.fieldMappings[i] fieldMap := typeInf.fieldMappings[i]
if fieldMap.complexType { if fieldMap.complexType {
@ -284,8 +270,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
err = scanner.Scan(cellValue) err = scanner.Scan(cellValue)
if err != nil { if err != nil {
err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String())
return
} }
updated = true updated = true
} else { } else {
@ -294,12 +279,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
if cellValue != nil { if cellValue != nil {
updated = true updated = true
initializeValueIfNilPtr(fieldValue) initializeValueIfNilPtr(fieldValue)
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if err != nil {
err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String())
return
}
} }
} }
} }
@ -310,9 +290,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
if destPtrValue.Kind() != reflect.Ptr { utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
return false, errors.New("jet: Internal error. ")
}
destValueKind := destPtrValue.Elem().Kind() destValueKind := destPtrValue.Elem().Kind()
@ -321,7 +299,7 @@ func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrVa
} else if destValueKind == reflect.Slice { } else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else { } else {
return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String()) panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
} }
} }
@ -331,14 +309,12 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
if dest.Kind() != reflect.Ptr { if dest.Kind() != reflect.Ptr {
destPtrValue = dest.Addr() destPtrValue = dest.Addr()
} else if dest.Kind() == reflect.Ptr { } else {
if dest.IsNil() { if dest.IsNil() {
destPtrValue = reflect.New(dest.Type().Elem()) destPtrValue = reflect.New(dest.Type().Elem())
} else { } else {
destPtrValue = dest destPtrValue = dest
} }
} else {
return false, errors.New("jet: Internal error. ")
} }
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
@ -399,7 +375,7 @@ func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error { func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error {
if slicePtrValue.IsNil() { if slicePtrValue.IsNil() {
panic("Slice is nil") panic("jet: internal, slice is nil")
} }
sliceValue := slicePtrValue.Elem() sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem() sliceElemType := sliceValue.Type().Elem()
@ -410,8 +386,12 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
newElemValue = objPtrValue.Elem() newElemValue = objPtrValue.Elem()
} }
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
}
if !newElemValue.Type().AssignableTo(sliceElemType) { if !newElemValue.Type().AssignableTo(sliceElemType) {
return fmt.Errorf("jet: can't append %s to %s slice ", newElemValue.Type().String(), sliceValue.Type().String()) panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
} }
sliceValue.Set(reflect.Append(sliceValue, newElemValue)) sliceValue.Set(reflect.Append(sliceValue, newElemValue))
@ -465,6 +445,7 @@ func toCommonIdentifier(name string) string {
} }
func initializeValueIfNilPtr(value reflect.Value) { func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() { if !value.IsValid() || !value.CanSet() {
return return
} }
@ -490,55 +471,119 @@ func valueToString(value reflect.Value) string {
valueInterface = value.Interface() valueInterface = value.Interface()
} }
if t, ok := valueInterface.(time.Time); ok { if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String() return t.String()
} }
return fmt.Sprintf("%#v", valueInterface) return fmt.Sprintf("%#v", valueInterface)
} }
func isGoBaseType(objType reflect.Type) bool { var timeType = reflect.TypeOf(time.Now())
typeStr := objType.String() var uuidType = reflect.TypeOf(uuid.New())
switch typeStr { func isSimpleModelType(objType reflect.Type) bool {
case "string", "int", "int16", "int32", "int64", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8", objType = indirectType(objType)
"*string", "*int", "*int16", "*int32", "*int64", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String,
reflect.Bool:
return true
case reflect.Slice:
return objType.Elem().Kind() == reflect.Uint8 //[]byte
case reflect.Struct:
return objType == timeType || objType == uuidType // time.Time || uuid.UUID
}
return false
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
return true return true
} }
return false return false
} }
func setReflectValue(source, destination reflect.Value) error { func tryAssign(source, destination reflect.Value) bool {
var sourceElem reflect.Value if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if isIntegerType(source.Type()) && destination.Type() == boolType {
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
}
return false
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
if destination.Kind() == reflect.Ptr { if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
sourceElem = source if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else { } else {
if source.CanAddr() { if source.CanAddr() {
sourceElem = source.Addr() source = source.Addr()
} else { } else {
sourceCopy := reflect.New(source.Type()) sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source) sourceCopy.Elem().Set(source)
sourceElem = sourceCopy source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
} }
} }
} else { } else {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
sourceElem = source.Elem() if source.IsNil() {
} else { return
sourceElem = source }
source = source.Elem()
}
if tryAssign(source, destination) {
return
} }
} }
if !sourceElem.Type().AssignableTo(destination.Type()) { panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
return errors.New("jet: can't set " + sourceElem.Type().String() + " to " + destination.Type().String())
}
destination.Set(sourceElem)
return nil
} }
func createScanValue(columnTypes []*sql.ColumnType) []interface{} { func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
@ -555,35 +600,49 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
return values return values
} }
var nullFloatType = reflect.TypeOf(internal.NullFloat32{}) var boolType = reflect.TypeOf(true)
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{}) var nullStringType = reflect.TypeOf(sql.NullString{})
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullTimeType = reflect.TypeOf(internal.NullTime{}) var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type { func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() { switch columnType.DatabaseTypeName() {
case "INT2": case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type return nullInt16Type
case "INT4": case "INT4", "MEDIUMINT", "INT":
return nullInt32Type return nullInt32Type
case "INT8": case "INT8", "BIGINT":
return nullInt64Type return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML": case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType return nullStringType
case "FLOAT4": case "FLOAT4":
return nullFloatType return nullFloat32Type
case "FLOAT8", "NUMERIC", "DECIMAL": case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE":
return nullFloat64Type return nullFloat64Type
case "BOOL": case "BOOL":
return nullBoolType return nullBoolType
case "BYTEA": case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType return nullByteArrayType
case "DATE", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ": case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType return nullTimeType
default: default:
return nullStringType return nullStringType
@ -697,7 +756,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
field := structType.Field(i) field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field) newTypeName, fieldName := getTypeAndFieldName(typeName, field)
if !isGoBaseType(field.Type) { if !isSimpleModelType(field.Type) {
var structType reflect.Type var structType reflect.Type
if field.Type.Kind() == reflect.Struct { if field.Type.Kind() == reflect.Struct {
structType = field.Type structType = field.Type
@ -749,7 +808,7 @@ func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer) valuer, ok := s.row[index].(driver.Valuer)
if !ok { if !ok {
panic("Scan value doesn't implement driver.Valuer") panic("jet: internal error, scan value doesn't implement driver.Valuer")
} }
value, err := valuer.Value() value, err := valuer.Value()
@ -791,3 +850,11 @@ func indirectType(reflectType reflect.Type) reflect.Type {
} }
return reflectType.Elem() return reflectType.Elem()
} }
func fieldToString(field *reflect.StructField) string {
if field == nil {
return ""
}
return " at '" + field.Name + " " + field.Type.String() + "'"
}

View file

@ -2,9 +2,12 @@ package internal
import ( import (
"database/sql/driver" "database/sql/driver"
"strconv"
"time" "time"
) )
//===============================================================//
// NullByteArray struct // NullByteArray struct
type NullByteArray struct { type NullByteArray struct {
ByteArray []byte ByteArray []byte
@ -31,6 +34,8 @@ func (nb NullByteArray) Value() (driver.Value, error) {
return nb.ByteArray, nil return nb.ByteArray, nil
} }
//===============================================================//
// NullTime struct // NullTime struct
type NullTime struct { type NullTime struct {
Time time.Time Time time.Time
@ -38,8 +43,20 @@ type NullTime struct {
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error { func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = value.(time.Time) switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, nt.Valid = parseTime(string(v))
return
case string:
nt.Time, nt.Valid = parseTime(v)
return
}
nt.Valid = false
return nil return nil
} }
@ -51,24 +68,49 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil return nt.Time, nil
} }
// NullInt32 struct const formatTime = "2006-01-02 15:04:05.999999"
type NullInt32 struct {
Int32 int32 func parseTime(timeStr string) (t time.Time, valid bool) {
Valid bool // Valid is true if Int64 is not NULL
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
}
t, err := time.Parse(format, timeStr)
return t, err == nil
}
//===============================================================//
// NullInt8 struct
type NullInt8 struct {
Int8 int8
Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error { func (n *NullInt8) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case int64: case int64:
n.Int32, n.Valid = int32(v), true n.Int8, n.Valid = int8(v), true
return nil return nil
case int32: case int8:
n.Int32, n.Valid = v, true n.Int8, n.Valid = v, true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
return nil return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 8)
if err == nil {
n.Int8, n.Valid = int8(intV), true
return nil
}
} }
n.Valid = false n.Valid = false
@ -77,21 +119,24 @@ func (n *NullInt32) Scan(value interface{}) error {
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) { func (n NullInt8) Value() (driver.Value, error) {
if !n.Valid { if !n.Valid {
return nil, nil return nil, nil
} }
return n.Int32, nil return n.Int8, nil
} }
//===============================================================//
// NullInt16 struct // NullInt16 struct
type NullInt16 struct { type NullInt16 struct {
Int16 int16 Int16 int16
Valid bool // Valid is true if Int64 is not NULL Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error { func (n *NullInt16) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case int64: case int64:
n.Int16, n.Valid = int16(v), true n.Int16, n.Valid = int16(v), true
@ -99,9 +144,18 @@ func (n *NullInt16) Scan(value interface{}) error {
case int16: case int16:
n.Int16, n.Valid = v, true n.Int16, n.Valid = v, true
return nil return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8: case uint8:
n.Int16, n.Valid = int16(v), true n.Int16, n.Valid = int16(v), true
return nil return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
return nil
}
} }
n.Valid = false n.Valid = false
@ -117,10 +171,63 @@ func (n NullInt16) Value() (driver.Value, error) {
return n.Int16, nil return n.Int16, nil
} }
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case int64:
n.Int32, n.Valid = int32(v), true
return nil
case int32:
n.Int32, n.Valid = v, true
return nil
case int16:
n.Int32, n.Valid = int32(v), true
return nil
case uint16:
n.Int32, n.Valid = int32(v), true
return nil
case int8:
n.Int32, n.Valid = int32(v), true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32)
if err == nil {
n.Int32, n.Valid = int32(intV), true
return nil
}
}
n.Valid = false
return nil
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
//===============================================================//
// NullFloat32 struct // NullFloat32 struct
type NullFloat32 struct { type NullFloat32 struct {
Float32 float32 Float32 float32
Valid bool // Valid is true if Int64 is not NULL Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.

View file

@ -1,194 +0,0 @@
package jet
import (
"errors"
)
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
clause
projection
groupByClause
orderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression
AS(alias string) projection
// Expression will be used to sort query result in ascending order
ASC() orderByClause
// Expression will be used to sort query result in ascending order
DESC() orderByClause
}
type expressionInterfaceImpl struct {
parent Expression
}
func (e *expressionInterfaceImpl) from(subQuery SelectTable) projection {
return e.parent
}
func (e *expressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NULL")
}
func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NOT NULL")
}
func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.parent, WRAP(expressions...), "IN")
}
func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.parent, WRAP(expressions...), "NOT IN")
}
func (e *expressionInterfaceImpl) AS(alias string) projection {
return newAlias(e.parent, alias)
}
func (e *expressionInterfaceImpl) ASC() orderByClause {
return newOrderByClause(e.parent, true)
}
func (e *expressionInterfaceImpl) DESC() orderByClause {
return newOrderByClause(e.parent, false)
}
func (e *expressionInterfaceImpl) serializeForGroupBy(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForProjection(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct {
lhs, rhs Expression
operator string
}
func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpression {
binaryExpression := binaryOpExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
return binaryExpression
}
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c == nil {
return errors.New("jet: binary Expression is nil")
}
if c.lhs == nil {
return errors.New("jet: nil lhs")
}
if c.rhs == nil {
return errors.New("jet: nil rhs")
}
wrap := !contains(options, noWrap)
if wrap {
out.writeString("(")
}
if err := c.lhs.serialize(statement, out); err != nil {
return err
}
out.writeString(c.operator)
if err := c.rhs.serialize(statement, out); err != nil {
return err
}
if wrap {
out.writeString(")")
}
return nil
}
// A prefix operator Expression
type prefixOpExpression struct {
expression Expression
operator string
}
func newPrefixExpression(expression Expression, operator string) prefixOpExpression {
prefixExpression := prefixOpExpression{
expression: expression,
operator: operator,
}
return prefixExpression
}
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Prefix Expression is nil")
}
out.writeString(p.operator + " ")
if p.expression == nil {
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
}
return nil
}
// A postifx operator Expression
type postfixOpExpression struct {
expression Expression
operator string
}
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression {
postfixOpExpression := postfixOpExpression{
expression: expression,
operator: operator,
}
return postfixOpExpression
}
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Postifx operator Expression is nil")
}
if p.expression == nil {
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
}
out.writeString(p.operator)
return nil
}

View file

@ -0,0 +1,174 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// ColumnMetaData struct
type ColumnMetaData struct {
Name string
IsNullable bool
DataType string
EnumName string
IsUnsigned bool
SqlBuilderColumnType string
GoBaseType string
GoModelType string
}
func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData {
columnMetaData := ColumnMetaData{
Name: name,
IsNullable: isNullable,
DataType: dataType,
EnumName: enumName,
IsUnsigned: isUnsigned,
}
columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType()
columnMetaData.GoBaseType = columnMetaData.getGoBaseType()
columnMetaData.GoModelType = columnMetaData.getGoModelType()
return columnMetaData
}
// getSqlBuilderColumnType returns type of jet sql builder column
func (c ColumnMetaData) getSqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.")
return "String"
}
}
// getGoBaseType returns model type for column info.
func (c ColumnMetaData) getGoBaseType() string {
switch c.DataType {
case "USER-DEFINED", "enum":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "tinyint":
return "int8"
case "smallint",
"year":
return "int16"
case "integer",
"mediumint", "int": //MySQL
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone",
"timestamp", "datetime", "time": // MySQL
return "time.Time"
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnMetaData) getGoModelType() string {
typeStr := c.GoBaseType
if strings.Contains(typeStr, "int") && c.IsUnsigned {
typeStr = "u" + typeStr
}
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnMetaData{}
for rows.Next() {
var name, isNullable, dataType, enumName string
var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
if err != nil {
return nil, err
}
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,140 @@
package metadata
import (
"database/sql"
"strings"
)
type DialectQuerySet interface {
ListOfTablesQuery() string
PrimaryKeysQuery() string
ListOfColumnsQuery() string
ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error)
}
type PostgresQuerySet struct{}
func (p *PostgresQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';
`
}
func (p *PostgresQuerySet) PrimaryKeysQuery() string {
return `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY';
`
}
func (p *PostgresQuerySet) ListOfColumnsQuery() string {
return `
SELECT column_name, is_nullable, data_type, udt_name, FALSE
FROM information_schema.columns
where table_schema = $1 and table_name = $2
order by ordinal_position;`
}
func (p *PostgresQuerySet) ListOfEnumsQuery() string {
return `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
}
func (p *PostgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) {
return getEnumInfos(db, p, schemaName)
}
// =======================================================================//
type MySqlQuerySet struct{}
func (m *MySqlQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = 'BASE TABLE';
`
}
func (m *MySqlQuerySet) PrimaryKeysQuery() string {
return `
SELECT k.column_name
FROM information_schema.table_constraints t
JOIN information_schema.key_column_usage k
USING(constraint_name,table_schema,table_name)
WHERE t.constraint_type='PRIMARY KEY'
AND t.table_schema= ?
AND t.table_name= ?;
`
}
func (m *MySqlQuerySet) ListOfColumnsQuery() string {
return `
SELECT COLUMN_NAME,
IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE),
IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''),
COLUMN_TYPE LIKE '%unsigned%'
FROM information_schema.columns
WHERE table_schema = ? and table_name = ?
ORDER BY ordinal_position;
`
}
func (m *MySqlQuerySet) ListOfEnumsQuery() string {
return `
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5)
FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE';
`
}
func (m *MySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) {
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var enumName string
var enumValues string
err = rows.Scan(&enumName, &enumValues)
if err != nil {
return nil, err
}
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
ret = append(ret, EnumMetaData{
name: enumName,
Values: strings.Split(enumValues, ","),
})
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,32 +1,23 @@
package postgresmeta package metadata
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
) )
// EnumInfo struct // EnumMetaData struct
type EnumInfo struct { type EnumMetaData struct {
name string name string
Values []string Values []string
} }
// Name returns enum name // Name returns enum name
func (e EnumInfo) Name() string { func (e EnumMetaData) Name() string {
return e.name return e.name
} }
func getEnumInfos(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { func getEnumInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) {
query := `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
rows, err := db.Query(query, schemaName) rows, err := db.Query(querySet.ListOfEnumsQuery(), schemaName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -55,10 +46,10 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
return nil, err return nil, err
} }
ret := []metadata.MetaData{} ret := []MetaData{}
for enumName, enumValues := range enumsInfosMap { for enumName, enumValues := range enumsInfosMap {
ret = append(ret, EnumInfo{ ret = append(ret, EnumMetaData{
enumName, enumName,
enumValues, enumValues,
}) })

View file

@ -1,142 +0,0 @@
package postgresmeta
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// ColumnInfo metadata struct
type ColumnInfo struct {
Name string
IsNullable bool
DataType string
EnumName string
}
// SqlBuilderColumnType returns type of jet sql builder column
func (c ColumnInfo) SqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint":
return "Integer"
case "date":
return "Date"
case "timestamp without time zone":
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone":
return "Time"
case "time with time zone":
return "Timez"
case "USER-DEFINED", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY":
return "String"
case "real", "numeric", "decimal", "double precision":
return "Float"
default:
fmt.Println("Unsupported sql type: " + c.DataType + ", using string column instead for sql builder.")
return "String"
}
}
// GoBaseType returns model type for column info.
func (c ColumnInfo) GoBaseType() string {
switch c.DataType {
case "USER-DEFINED":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "smallint":
return "int16"
case "integer":
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone":
return "time.Time"
case "bytea":
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY":
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision":
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("Unsupported sql type: " + c.DataType + ", " + c.EnumName + ", using string instead for model type.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type, udt_name
FROM information_schema.columns
where table_catalog = $1 and table_schema = $2 and table_name = $3
order by ordinal_position;`
rows, err := db.Query(query, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnInfo{}
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
columnInfo.IsNullable = isNullable == "YES"
if err != nil {
return nil, err
}
ret = append(ret, columnInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,76 +0,0 @@
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// SchemaInfo metadata struct
type SchemaInfo struct {
DatabaseName string
Name string
TableInfos []metadata.MetaData
EnumInfos []metadata.MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
schemaInfo.DatabaseName = databaseName
schemaInfo.Name = schemaName
schemaInfo.TableInfos, err = getTableInfos(db, databaseName, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = getEnumInfos(db, schemaName)
if err != nil {
return
}
return
}
func getTableInfos(db *sql.DB, dbName, schemaName string) ([]metadata.MetaData, error) {
query := `
SELECT table_name
FROM information_schema.tables
where table_catalog = $1 and table_schema = $2 and table_type = 'BASE TABLE';
`
rows, err := db.Query(query, dbName, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []metadata.MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,68 @@
package metadata
import (
"database/sql"
"fmt"
)
// SchemaMetaData struct
type SchemaMetaData struct {
TableInfos []MetaData
EnumInfos []MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) {
schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName)
if err != nil {
return
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
return
}
func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,31 +1,31 @@
package postgresmeta package metadata
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
) )
// TableInfo metadata struct // TableMetaData metadata struct
type TableInfo struct { type TableMetaData struct {
SchemaName string SchemaName string
name string name string
PrimaryKeys map[string]bool PrimaryKeys map[string]bool
Columns []ColumnInfo Columns []ColumnMetaData
} }
// Name returns table info name // Name returns table info name
func (t TableInfo) Name() string { func (t TableMetaData) Name() string {
return t.name return t.name
} }
// IsPrimaryKey returns if column is a part of primary key // IsPrimaryKey returns if column is a part of primary key
func (t TableInfo) IsPrimaryKey(column string) bool { func (t TableMetaData) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column] return t.PrimaryKeys[column]
} }
// MutableColumns returns list of mutable columns for table // MutableColumns returns list of mutable columns for table
func (t TableInfo) MutableColumns() []ColumnInfo { func (t TableMetaData) MutableColumns() []ColumnMetaData {
ret := []ColumnInfo{} ret := []ColumnMetaData{}
for _, column := range t.Columns { for _, column := range t.Columns {
if t.IsPrimaryKey(column.Name) { if t.IsPrimaryKey(column.Name) {
@ -39,11 +39,11 @@ func (t TableInfo) MutableColumns() []ColumnInfo {
} }
// GetImports returns model imports for table. // GetImports returns model imports for table.
func (t TableInfo) GetImports() []string { func (t TableMetaData) GetImports() []string {
imports := map[string]string{} imports := map[string]string{}
for _, column := range t.Columns { for _, column := range t.Columns {
columnType := column.GoBaseType() columnType := column.GoBaseType
switch columnType { switch columnType {
case "time.Time": case "time.Time":
@ -63,22 +63,22 @@ func (t TableInfo) GetImports() []string {
} }
// GoStructName returns go struct name for sql builder // GoStructName returns go struct name for sql builder
func (t TableInfo) GoStructName() string { func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table" return utils.ToGoIdentifier(t.name) + "Table"
} }
// GetTableInfo returns table info metadata // GetTableInfo returns table info metadata
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) { func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) {
tableInfo.SchemaName = schemaName tableInfo.SchemaName = schemaName
tableInfo.name = tableName tableInfo.name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, dbName, schemaName, tableName) tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName)
if err != nil { if err != nil {
return return
} }
tableInfo.Columns, err = getColumnInfos(db, dbName, schemaName, tableName) tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName)
if err != nil { if err != nil {
return return
@ -87,15 +87,9 @@ func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo T
return return
} }
func getPrimaryKeys(db *sql.DB, dbName, schemaName, tableName string) (map[string]bool, error) { func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) {
query := `
SELECT c.column_name rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_catalog = $1 AND t.table_schema = $2 AND t.table_name = $3 AND t.constraint_type = 'PRIMARY KEY';
`
rows, err := db.Query(query, dbName, schemaName, tableName)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -0,0 +1,118 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"path/filepath"
"text/template"
"time"
)
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
if len(tables) == 0 && len(enums) == 0 {
return nil
}
fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir)
if err != nil {
return err
}
fmt.Println("Generating table sql builder files...")
err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(destDir, "model", tableModelTemplate, tables, dialect)
if err != nil {
return err
}
if len(enums) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(destDir, "model", enumModelTemplate, enums, dialect)
if err != nil {
return err
}
}
fmt.Println("Done")
return nil
}
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
"dialect": func() jet.Dialect {
return dialect1
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -1,4 +1,4 @@
package postgres package template
var autoGenWarningTemplate = ` var autoGenWarningTemplate = `
// //
@ -11,7 +11,7 @@ var autoGenWarningTemplate = `
` `
var sqlBuilderTableTemplate = ` var tableSQLBuilderTemplate = `
{{define "column-list" -}} {{define "column-list" -}}
{{- range $i, $c := . }} {{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
@ -21,21 +21,21 @@ var sqlBuilderTableTemplate = `
package table package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet/{{dialect.PackageName}}"
) )
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
type {{.GoStructName}} struct { type {{.GoStructName}} struct {
jet.Table {{dialect.PackageName}}.Table
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}} jet.Column{{.SqlBuilderColumnType}} {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}} {{- end}}
AllColumns jet.ColumnList AllColumns {{dialect.PackageName}}.IColumnList
MutableColumns jet.ColumnList MutableColumns {{dialect.PackageName}}.IColumnList
} }
// creates new {{.GoStructName}} with assigned alias // creates new {{.GoStructName}} with assigned alias
@ -50,26 +50,26 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
func new{{.GoStructName}}() *{{.GoStructName}} { func new{{.GoStructName}}() *{{.GoStructName}} {
var ( var (
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}Column = jet.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}} {{- end}}
) )
return &{{.GoStructName}}{ return &{{.GoStructName}}{
Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}} {{- end}}
AllColumns: jet.ColumnList{ {{template "column-list" .Columns}} }, AllColumns: {{dialect.PackageName}}.ColumnList( {{template "column-list" .Columns}} ),
MutableColumns: jet.ColumnList{ {{template "column-list" .MutableColumns}} }, MutableColumns: {{dialect.PackageName}}.ColumnList( {{template "column-list" .MutableColumns}} ),
} }
} }
` `
var dataModelTemplate = `package model var tableModelTemplate = `package model
{{ if .GetImports }} {{ if .GetImports }}
import ( import (
@ -85,6 +85,22 @@ type {{ToGoIdentifier .Name}} struct {
{{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` {{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + `
{{- end}} {{- end}}
} }
`
var enumSQLBuilderTemplate = `package enum
import "github.com/go-jet/jet/{{dialect.PackageName}}"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"),
{{- end}}
}
` `
var enumModelTemplate = `package model var enumModelTemplate = `package model
@ -121,17 +137,3 @@ func (e {{ToGoIdentifier $.Name}}) String() string {
} }
` `
var enumTypeTemplate = `package enum
import "github.com/go-jet/jet"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} jet.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: jet.NewEnumValue("{{$element}}"),
{{- end}}
}
`

View file

@ -0,0 +1,71 @@
package mysql
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/mysql"
"path"
)
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
if err != nil {
return err
}
defer utils.DBClose(db)
fmt.Println("Retrieving database information...")
// No schemas in MySQL
dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &metadata.MySqlQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName)
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect)
if err != nil {
return err
}
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" {
connectionString += "?" + dbConn.Params
}
db, err := sql.Open("mysql", connectionString)
fmt.Println("Connecting to MySQL database: " + connectionString)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -1,134 +0,0 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/metadata/postgresmeta"
"github.com/go-jet/jet/internal/utils"
"path"
"path/filepath"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
SchemaName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return err
}
defer db.Close()
err = db.Ping()
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := postgresmeta.GetSchemaInfo(db, dbConn.DBName, dbConn.SchemaName)
if err != nil {
return err
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
if len(schemaInfo.TableInfos) == 0 && len(schemaInfo.EnumInfos) == 0 {
return nil
}
schemaGenPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
fmt.Println("Destination directory:", schemaGenPath)
fmt.Println("Cleaning up destination directory...")
err = utils.CleanUpGeneratedFiles(schemaGenPath)
if err != nil {
return err
}
fmt.Println("Generating table sql builder files...")
err = generate(schemaInfo, destDir, "table", sqlBuilderTableTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(schemaInfo, destDir, "model", dataModelTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
if len(schemaInfo.EnumInfos) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(schemaInfo, destDir, "enum", enumTypeTemplate, schemaInfo.EnumInfos)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos)
if err != nil {
return err
}
}
fmt.Println("Done")
return nil
}
func generate(schemaInfo postgresmeta.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := utils.GenerateTemplate(autoGenWarningTemplate, nil)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := utils.GenerateTemplate(template, metaData)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,73 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/postgres"
"path"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
SchemaName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
defer utils.DBClose(db)
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &metadata.PostgresQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect)
if err != nil {
return err
}
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -1,5 +0,0 @@
package jet
type groupByClause interface {
serializeForGroupBy(statement statementType, out *sqlBuilder) error
}

View file

@ -1,170 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
// Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...projection) InsertStatement
}
func newInsertStatement(t WritableTable, columns []column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
}
}
type insertStatementImpl struct {
table WritableTable
columns []column
rows [][]clause
query SelectStatement
returning []projection
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromValues(value, values))
return i
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromModel(i.getColumns(), data))
return i
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowsFromModels(i.getColumns(), data)...)
return i
}
func (i *insertStatementImpl) RETURNING(projections ...projection) InsertStatement {
i.returning = projections
return i
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement
return i
}
func (i *insertStatementImpl) getColumns() []column {
if len(i.columns) > 0 {
return i.columns
}
return i.table.columns()
}
func (i *insertStatementImpl) DebugSql() (query string, err error) {
return debugSql(i)
}
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
queryData := &sqlBuilder{}
queryData.newLine()
queryData.writeString("INSERT INTO")
if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil")
}
err = i.table.serialize(insertStatement, queryData)
if err != nil {
return
}
if len(i.columns) > 0 {
queryData.writeString("(")
err = serializeColumnNames(i.columns, queryData)
if err != nil {
return
}
queryData.writeString(")")
}
if len(i.rows) == 0 && i.query == nil {
return "", nil, errors.New("jet: no row values or query specified")
}
if len(i.rows) > 0 && i.query != nil {
return "", nil, errors.New("jet: only row values or query has to be specified")
}
if len(i.rows) > 0 {
queryData.writeString("VALUES")
for rowIndex, row := range i.rows {
if rowIndex > 0 {
queryData.writeString(",")
}
queryData.increaseIdent()
queryData.newLine()
queryData.writeString("(")
err = serializeClauseList(insertStatement, row, queryData)
if err != nil {
return "", nil, err
}
queryData.writeByte(')')
queryData.decreaseIdent()
}
}
if i.query != nil {
err = i.query.serialize(insertStatement, queryData)
if err != nil {
return
}
}
if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
return
}
sql, args = queryData.finalize()
return
}
func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db)
}
func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, i, db)
}

View file

@ -1,13 +0,0 @@
package snaker
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestDb(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Snaker Suite")
}

View file

@ -1,40 +1,16 @@
package snaker package snaker
import ( import (
. "github.com/onsi/ginkgo" "gotest.tools/assert"
. "github.com/onsi/gomega" "testing"
) )
var _ = Describe("Snaker", func() { func TestSnakeToCamel(t *testing.T) {
assert.Equal(t, SnakeToCamel(""), "")
Describe("SnakeToCamel test", func() { assert.Equal(t, SnakeToCamel("potato_"), "Potato")
It("should return an empty string on an empty input", func() { assert.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
Expect(SnakeToCamel("")).To(Equal("")) assert.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
}) assert.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")
assert.Equal(t, SnakeToCamel("id"), "ID")
It("should not blow up on trailing _", func() { assert.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient")
Expect(SnakeToCamel("potato_")).To(Equal("Potato")) }
})
It("should return a snaked text as camel case", func() {
Expect(SnakeToCamel("this_has_to_be_uppercased")).To(
Equal("ThisHasToBeUppercased"))
})
It("should return a snaked text as camel case, except the word ID", func() {
Expect(SnakeToCamel("this_is_an_id")).To(Equal("ThisIsAnID"))
})
It("should return 'id' not as uppercase", func() {
Expect(SnakeToCamel("this_is_an_identifier")).To(Equal("ThisIsAnIdentifier"))
})
It("should simply work with id", func() {
Expect(SnakeToCamel("id")).To(Equal("ID"))
})
It("should work with initialism where only certain characters are uppercase", func() {
Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient"))
})
})
})

28
internal/jet/alias.go Normal file
View file

@ -0,0 +1,28 @@
package jet
type alias struct {
expression Expression
alias string
}
func newAlias(expression Expression, aliasName string) Projection {
return &alias{
expression: expression,
alias: aliasName,
}
}
func (a *alias) fromImpl(subQuery SelectTable) Projection {
column := newColumn(a.alias, "", nil)
column.Parent = &column
column.subQuery = subQuery
return &column
}
func (a *alias) serializeForProjection(statement StatementType, out *SqlBuilder) {
a.expression.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(a.alias)
}

View file

@ -86,25 +86,25 @@ func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression {
//---------------------------------------------------// //---------------------------------------------------//
type binaryBoolExpression struct { type binaryBoolExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
boolInterfaceImpl boolInterfaceImpl
binaryOpExpression binaryOpExpression
} }
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression { func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression {
boolExpression := binaryBoolExpression{} binaryBoolExpression := binaryBoolExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator, additionalParams...)
boolExpression.expressionInterfaceImpl.parent = &boolExpression binaryBoolExpression.ExpressionInterfaceImpl.Parent = &binaryBoolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
return &boolExpression return &binaryBoolExpression
} }
//---------------------------------------------------// //---------------------------------------------------//
type prefixBoolExpression struct { type prefixBoolExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
boolInterfaceImpl boolInterfaceImpl
prefixOpExpression prefixOpExpression
@ -114,7 +114,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio
exp := prefixBoolExpression{} exp := prefixBoolExpression{}
exp.prefixOpExpression = newPrefixExpression(expression, operator) exp.prefixOpExpression = newPrefixExpression(expression, operator)
exp.expressionInterfaceImpl.parent = &exp exp.ExpressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp exp.boolInterfaceImpl.parent = &exp
return &exp return &exp
@ -122,7 +122,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio
//---------------------------------------------------// //---------------------------------------------------//
type postfixBoolOpExpression struct { type postfixBoolOpExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
boolInterfaceImpl boolInterfaceImpl
postfixOpExpression postfixOpExpression
@ -132,7 +132,7 @@ func newPostifxBoolExpression(expression Expression, operator string) BoolExpres
exp := postfixBoolOpExpression{} exp := postfixBoolOpExpression{}
exp.postfixOpExpression = newPostfixOpExpression(expression, operator) exp.postfixOpExpression = newPostfixOpExpression(expression, operator)
exp.expressionInterfaceImpl.parent = &exp exp.ExpressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp exp.boolInterfaceImpl.parent = &exp
return &exp return &exp

View file

@ -5,9 +5,8 @@ import (
) )
func TestBoolExpressionEQ(t *testing.T) { func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: nil rhs")
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)") assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.col_bool = $1)", true) assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
} }
func TestBoolExpressionNOT_EQ(t *testing.T) { func TestBoolExpressionNOT_EQ(t *testing.T) {
@ -57,6 +56,7 @@ func TestBinaryBoolExpression(t *testing.T) {
boolExpression := Int(2).EQ(Int(3)) boolExpression := Int(2).EQ(Int(3))
assertClauseSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3)) assertClauseSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3))
assertProjectionSerialize(t, boolExpression, "$1 = $2", int64(2), int64(3)) assertProjectionSerialize(t, boolExpression, "$1 = $2", int64(2), int64(3))
assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"), assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"),
`($1 = $2) AS "alias_eq_expression"`, int64(2), int64(3)) `($1 = $2) AS "alias_eq_expression"`, int64(2), int64(3))
@ -71,20 +71,6 @@ func TestBoolLiteral(t *testing.T) {
assertClauseSerialize(t, Bool(false), "$1", false) assertClauseSerialize(t, Bool(false), "$1", false)
} }
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`EXISTS (
SELECT $1
FROM db.table2
WHERE table1.col1 = table2.col3
)`, int64(1))
}
func TestBoolExp(t *testing.T) { func TestBoolExp(t *testing.T) {
assertClauseSerialize(t, BoolExp(String("true")), "$1", "true") assertClauseSerialize(t, BoolExp(String("true")), "$1", "true")
assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true") assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true")

51
internal/jet/cast.go Normal file
View file

@ -0,0 +1,51 @@
package jet
type Cast interface {
AS(castType string) Expression
}
type CastImpl struct {
expression Expression
}
func NewCastImpl(expression Expression) Cast {
castImpl := CastImpl{
expression: expression,
}
return &castImpl
}
func (b *CastImpl) AS(castType string) Expression {
castExp := &castExpression{
expression: b.expression,
cast: string(castType),
}
castExp.ExpressionInterfaceImpl.Parent = castExp
return castExp
}
type castExpression struct {
ExpressionInterfaceImpl
expression Expression
cast string
}
func (b *castExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
expression := b.expression
castType := b.cast
if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil {
castOverride(expression, String(castType))(statement, out, options...)
return
}
out.WriteString("CAST(")
expression.serialize(statement, out, options...)
out.WriteString("AS")
out.WriteString(castType + ")")
}

11
internal/jet/cast_test.go Normal file
View file

@ -0,0 +1,11 @@
package jet
import (
"testing"
)
func TestCastAS(t *testing.T) {
assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1))
assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
}

408
internal/jet/clause.go Normal file
View file

@ -0,0 +1,408 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
)
type Clause interface {
Serialize(statementType StatementType, out *SqlBuilder)
}
type ClauseWithProjections interface {
Clause
projections() ProjectionList
}
type ClauseSelect struct {
Distinct bool
Projections []Projection
}
func (s *ClauseSelect) projections() ProjectionList {
return s.Projections
}
func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("SELECT")
if s.Distinct {
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.WriteProjections(statementType, s.Projections)
}
type ClauseFrom struct {
Table Serializer
}
func (f *ClauseFrom) Serialize(statementType StatementType, out *SqlBuilder) {
if f.Table == nil {
return
}
out.NewLine()
out.WriteString("FROM")
out.IncreaseIdent()
f.Table.serialize(statementType, out)
out.DecreaseIdent()
}
type ClauseWhere struct {
Condition BoolExpression
Mandatory bool
}
func (c *ClauseWhere) Serialize(statementType StatementType, out *SqlBuilder) {
if c.Condition == nil {
if c.Mandatory {
panic("jet: WHERE clause not set")
}
return
}
out.NewLine()
out.WriteString("WHERE")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
out.DecreaseIdent()
}
type ClauseGroupBy struct {
List []GroupByClause
}
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) {
if len(c.List) == 0 {
return
}
out.NewLine()
out.WriteString("GROUP BY")
out.IncreaseIdent()
serializeGroupByClauseList(statementType, c.List, out)
out.DecreaseIdent()
}
type ClauseHaving struct {
Condition BoolExpression
}
func (c *ClauseHaving) Serialize(statementType StatementType, out *SqlBuilder) {
if c.Condition == nil {
return
}
out.NewLine()
out.WriteString("HAVING")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
out.DecreaseIdent()
}
type ClauseOrderBy struct {
List []OrderByClause
}
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SqlBuilder) {
if o.List == nil {
return
}
out.NewLine()
out.WriteString("ORDER BY")
out.IncreaseIdent()
serializeOrderByClauseList(statementType, o.List, out)
out.DecreaseIdent()
}
type ClauseLimit struct {
Count int64
}
func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) {
if l.Count >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(l.Count)
}
}
type ClauseOffset struct {
Count int64
}
func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) {
if o.Count >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(o.Count)
}
}
type ClauseFor struct {
Lock SelectLock
}
func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) {
if f.Lock == nil {
return
}
out.NewLine()
out.WriteString("FOR")
f.Lock.serialize(statementType, out)
}
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() ProjectionList {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}
return nil
}
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlBuilder) {
if len(s.Selects) < 2 {
panic("jet: UNION Statement must contain at least two SELECT statements")
}
for i, selectStmt := range s.Selects {
out.NewLine()
if i > 0 {
out.WriteString(s.Operator)
if s.All {
out.WriteString("ALL")
}
out.NewLine()
}
if selectStmt == nil {
panic("jet: select statement of '" + s.Operator + "' is nil")
}
selectStmt.serialize(statementType, out)
}
s.OrderBy.Serialize(statementType, out)
s.Limit.Serialize(statementType, out)
s.Offset.Serialize(statementType, out)
}
type ClauseUpdate struct {
Table SerializerTable
}
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("UPDATE")
if utils.IsNil(u.Table) {
panic("jet: table to update is nil")
}
u.Table.serialize(statementType, out)
}
type ClauseSet struct {
Columns []Column
Values []Serializer
}
func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("SET")
if len(s.Columns) != len(s.Values) {
panic("jet: mismatch in numbers of columns and values for SET clause")
}
out.IncreaseIdent(4)
for i, column := range s.Columns {
if i > 0 {
out.WriteString(", ")
out.NewLine()
}
if column == nil {
panic("jet: nil column in columns list for SET clause")
}
out.WriteString(column.Name())
out.WriteString(" = ")
s.Values[i].serialize(UpdateStatementType, out)
}
out.DecreaseIdent(4)
}
type ClauseInsert struct {
Table SerializerTable
Columns []Column
}
func (i *ClauseInsert) GetColumns() []Column {
if len(i.Columns) > 0 {
return i.Columns
}
return i.Table.columns()
}
func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("INSERT INTO")
if utils.IsNil(i.Table) {
panic("jet: table is nil for INSERT clause")
}
i.Table.serialize(statementType, out)
if len(i.Columns) > 0 {
out.WriteString("(")
SerializeColumnNames(i.Columns, out)
out.WriteString(")")
}
}
type ClauseValuesQuery struct {
ClauseValues
ClauseQuery
}
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SqlBuilder) {
if len(v.Rows) == 0 && v.Query == nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
if len(v.Rows) > 0 && v.Query != nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
v.ClauseValues.Serialize(statementType, out)
v.ClauseQuery.Serialize(statementType, out)
}
type ClauseValues struct {
Rows [][]Serializer
}
func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) {
if len(v.Rows) == 0 {
return
}
out.WriteString("VALUES")
for rowIndex, row := range v.Rows {
if rowIndex > 0 {
out.WriteString(",")
}
out.IncreaseIdent()
out.NewLine()
out.WriteString("(")
SerializeClauseList(statementType, row, out)
out.WriteByte(')')
out.DecreaseIdent()
}
}
type ClauseQuery struct {
Query SerializerStatement
}
func (v *ClauseQuery) Serialize(statementType StatementType, out *SqlBuilder) {
if v.Query == nil {
return
}
v.Query.serialize(statementType, out)
}
type ClauseDelete struct {
Table SerializerTable
}
func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("DELETE FROM")
if d.Table == nil {
panic("jet: nil table in DELETE clause")
}
d.Table.serialize(statementType, out)
}
type ClauseStatementBegin struct {
Name string
Tables []SerializerTable
}
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString(d.Name)
for i, table := range d.Tables {
if i > 0 {
out.WriteString(", ")
}
table.serialize(statementType, out)
}
}
type ClauseOptional struct {
Name string
Show bool
InNewLine bool
}
func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) {
if !d.Show {
return
}
if d.InNewLine {
out.NewLine()
}
out.WriteString(d.Name)
}
type ClauseIn struct {
LockMode string
}
func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) {
if i.LockMode == "" {
return
}
out.WriteString("IN")
out.WriteString(string(i.LockMode))
out.WriteString("MODE")
}

View file

@ -0,0 +1,16 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestClauseSelect_Serialize(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: SELECT clause has to have at least one projection")
}()
selectClause := &ClauseSelect{}
selectClause.Serialize(SelectStatementType, &SqlBuilder{})
}

144
internal/jet/column.go Normal file
View file

@ -0,0 +1,144 @@
// Modeling of columns
package jet
type Column interface {
Name() string
TableName() string
setTableName(table string)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// Column is common column interface for all types of columns.
type ColumnExpression interface {
Column
Expression
}
// The base type for real materialized columns.
type columnImpl struct {
ExpressionInterfaceImpl
name string
tableName string
subQuery SelectTable
}
func newColumn(name string, tableName string, parent ColumnExpression) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
}
bc.ExpressionInterfaceImpl.Parent = parent
return bc
}
func (c *columnImpl) Name() string {
return c.name
}
func (c *columnImpl) TableName() string {
return c.tableName
}
func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string {
if c.tableName != "" {
return c.tableName + "." + c.name
}
return c.name
}
func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) {
if statement == SetStatementType {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.WriteAlias(c.defaultAlias()) //always quote
return
}
c.serialize(statement, out)
}
func (c columnImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
c.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(c.defaultAlias())
}
func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true)
} else {
if c.tableName != "" {
out.WriteIdentifier(c.tableName)
out.WriteByte('.')
}
out.WriteIdentifier(c.name)
}
}
//------------------------------------------------------//
type IColumnList interface {
Projection
Column
columns() []ColumnExpression
}
func ColumnList(columns ...ColumnExpression) IColumnList {
return columnListImpl(columns)
}
// ColumnList is redefined type to support list of columns as single Projection
type columnListImpl []ColumnExpression
func (cl columnListImpl) columns() []ColumnExpression {
return cl
}
func (cl columnListImpl) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl columnListImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl columnListImpl) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl columnListImpl) TableName() string { return "" }
func (cl columnListImpl) setTableName(name string) {}
func (cl columnListImpl) setSubQuery(subQuery SelectTable) {}
func (cl columnListImpl) defaultAlias() string { return "" }

View file

@ -4,7 +4,7 @@ import "testing"
func TestColumn(t *testing.T) { func TestColumn(t *testing.T) {
column := newColumn("col", "", nil) column := newColumn("col", "", nil)
column.expressionInterfaceImpl.parent = &column column.ExpressionInterfaceImpl.Parent = &column
assertClauseSerialize(t, column, "col") assertClauseSerialize(t, column, "col")
column.setTableName("table1") column.setTableName("table1")

View file

@ -3,7 +3,7 @@ package jet
// ColumnBool is interface for SQL boolean columns. // ColumnBool is interface for SQL boolean columns.
type ColumnBool interface { type ColumnBool interface {
BoolExpression BoolExpression
column Column
From(subQuery SelectTable) ColumnBool From(subQuery SelectTable) ColumnBool
} }
@ -14,7 +14,7 @@ type boolColumnImpl struct {
columnImpl columnImpl
} }
func (i *boolColumnImpl) from(subQuery SelectTable) projection { func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection {
newBoolColumn := BoolColumn(i.name) newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(i.tableName) newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery) newBoolColumn.setSubQuery(subQuery)
@ -23,7 +23,7 @@ func (i *boolColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := i.from(subQuery).(ColumnBool) newBoolColumn := i.fromImpl(subQuery).(ColumnBool)
return newBoolColumn return newBoolColumn
} }
@ -42,7 +42,7 @@ func BoolColumn(name string) ColumnBool {
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column. // ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface { type ColumnFloat interface {
FloatExpression FloatExpression
column Column
From(subQuery SelectTable) ColumnFloat From(subQuery SelectTable) ColumnFloat
} }
@ -52,7 +52,7 @@ type floatColumnImpl struct {
columnImpl columnImpl
} }
func (i *floatColumnImpl) from(subQuery SelectTable) projection { func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection {
newFloatColumn := FloatColumn(i.name) newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName) newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery) newFloatColumn.setSubQuery(subQuery)
@ -61,7 +61,7 @@ func (i *floatColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
newFloatColumn := i.from(subQuery).(ColumnFloat) newFloatColumn := i.fromImpl(subQuery).(ColumnFloat)
return newFloatColumn return newFloatColumn
} }
@ -80,7 +80,7 @@ func FloatColumn(name string) ColumnFloat {
// ColumnInteger is interface for SQL smallint, integer, bigint columns. // ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface { type ColumnInteger interface {
IntegerExpression IntegerExpression
column Column
From(subQuery SelectTable) ColumnInteger From(subQuery SelectTable) ColumnInteger
} }
@ -91,7 +91,7 @@ type integerColumnImpl struct {
columnImpl columnImpl
} }
func (i *integerColumnImpl) from(subQuery SelectTable) projection { func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection {
newIntColumn := IntegerColumn(i.name) newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(i.tableName) newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery) newIntColumn.setSubQuery(subQuery)
@ -100,7 +100,7 @@ func (i *integerColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return i.from(subQuery).(ColumnInteger) return i.fromImpl(subQuery).(ColumnInteger)
} }
// IntegerColumn creates named integer column. // IntegerColumn creates named integer column.
@ -118,7 +118,7 @@ func IntegerColumn(name string) ColumnInteger {
// bytea, uuid columns and enums types. // bytea, uuid columns and enums types.
type ColumnString interface { type ColumnString interface {
StringExpression StringExpression
column Column
From(subQuery SelectTable) ColumnString From(subQuery SelectTable) ColumnString
} }
@ -129,7 +129,7 @@ type stringColumnImpl struct {
columnImpl columnImpl
} }
func (i *stringColumnImpl) from(subQuery SelectTable) projection { func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection {
newStrColumn := StringColumn(i.name) newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(i.tableName) newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery) newStrColumn.setSubQuery(subQuery)
@ -138,7 +138,7 @@ func (i *stringColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return i.from(subQuery).(ColumnString) return i.fromImpl(subQuery).(ColumnString)
} }
// StringColumn creates named string column. // StringColumn creates named string column.
@ -155,7 +155,7 @@ func StringColumn(name string) ColumnString {
// ColumnTime is interface for SQL time column. // ColumnTime is interface for SQL time column.
type ColumnTime interface { type ColumnTime interface {
TimeExpression TimeExpression
column Column
From(subQuery SelectTable) ColumnTime From(subQuery SelectTable) ColumnTime
} }
@ -165,7 +165,7 @@ type timeColumnImpl struct {
columnImpl columnImpl
} }
func (i *timeColumnImpl) from(subQuery SelectTable) projection { func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimeColumn := TimeColumn(i.name) newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(i.tableName) newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery) newTimeColumn.setSubQuery(subQuery)
@ -174,7 +174,7 @@ func (i *timeColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return i.from(subQuery).(ColumnTime) return i.fromImpl(subQuery).(ColumnTime)
} }
// TimeColumn creates named time column // TimeColumn creates named time column
@ -190,7 +190,7 @@ func TimeColumn(name string) ColumnTime {
// ColumnTimez is interface of SQL time with time zone columns. // ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface { type ColumnTimez interface {
TimezExpression TimezExpression
column Column
From(subQuery SelectTable) ColumnTimez From(subQuery SelectTable) ColumnTimez
} }
@ -201,7 +201,7 @@ type timezColumnImpl struct {
columnImpl columnImpl
} }
func (i *timezColumnImpl) from(subQuery SelectTable) projection { func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimezColumn := TimezColumn(i.name) newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(i.tableName) newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery) newTimezColumn.setSubQuery(subQuery)
@ -210,7 +210,7 @@ func (i *timezColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return i.from(subQuery).(ColumnTimez) return i.fromImpl(subQuery).(ColumnTimez)
} }
// TimezColumn creates named time with time zone column. // TimezColumn creates named time with time zone column.
@ -227,7 +227,7 @@ func TimezColumn(name string) ColumnTimez {
// ColumnTimestamp is interface of SQL timestamp columns. // ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface { type ColumnTimestamp interface {
TimestampExpression TimestampExpression
column Column
From(subQuery SelectTable) ColumnTimestamp From(subQuery SelectTable) ColumnTimestamp
} }
@ -238,7 +238,7 @@ type timestampColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampColumnImpl) from(subQuery SelectTable) projection { func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimestampColumn := TimestampColumn(i.name) newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(i.tableName) newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery) newTimestampColumn.setSubQuery(subQuery)
@ -247,7 +247,7 @@ func (i *timestampColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return i.from(subQuery).(ColumnTimestamp) return i.fromImpl(subQuery).(ColumnTimestamp)
} }
// TimestampColumn creates named timestamp column // TimestampColumn creates named timestamp column
@ -264,7 +264,7 @@ func TimestampColumn(name string) ColumnTimestamp {
// ColumnTimestampz is interface of SQL timestamp with timezone columns. // ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface { type ColumnTimestampz interface {
TimestampzExpression TimestampzExpression
column Column
From(subQuery SelectTable) ColumnTimestampz From(subQuery SelectTable) ColumnTimestampz
} }
@ -275,7 +275,7 @@ type timestampzColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampzColumnImpl) from(subQuery SelectTable) projection { func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimestampzColumn := TimestampzColumn(i.name) newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(i.tableName) newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery) newTimestampzColumn.setSubQuery(subQuery)
@ -284,7 +284,7 @@ func (i *timestampzColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return i.from(subQuery).(ColumnTimestampz) return i.fromImpl(subQuery).(ColumnTimestampz)
} }
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
@ -301,7 +301,7 @@ func TimestampzColumn(name string) ColumnTimestampz {
// ColumnDate is interface of SQL date columns. // ColumnDate is interface of SQL date columns.
type ColumnDate interface { type ColumnDate interface {
DateExpression DateExpression
column Column
From(subQuery SelectTable) ColumnDate From(subQuery SelectTable) ColumnDate
} }
@ -312,7 +312,7 @@ type dateColumnImpl struct {
columnImpl columnImpl
} }
func (i *dateColumnImpl) from(subQuery SelectTable) projection { func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection {
newDateColumn := DateColumn(i.name) newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(i.tableName) newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery) newDateColumn.setSubQuery(subQuery)
@ -321,7 +321,7 @@ func (i *dateColumnImpl) from(subQuery SelectTable) projection {
} }
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return i.from(subQuery).(ColumnDate) return i.fromImpl(subQuery).(ColumnDate)
} }
// DateColumn creates named date column. // DateColumn creates named date column.

View file

@ -4,7 +4,9 @@ import (
"testing" "testing"
) )
var subQuery = table1.SELECT(table1ColFloat, table1ColInt).AsTable("sub_query") var subQuery = &SelectTableImpl{
alias: "sub_query",
}
func TestNewBoolColumn(t *testing.T) { func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery) boolColumn := BoolColumn("colBool").From(subQuery)

83
internal/jet/dialect.go Normal file
View file

@ -0,0 +1,83 @@
package jet
type Dialect interface {
Name() string
PackageName() string
OperatorSerializeOverride(operator string) SerializeOverride
FunctionSerializeOverride(function string) SerializeOverride
AliasQuoteChar() byte
IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc
}
type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...SerializeOption)
type SerializeOverride func(expressions ...Expression) SerializeFunc
type QueryPlaceholderFunc func(ord int) string
type DialectParams struct {
Name string
PackageName string
OperatorSerializeOverrides map[string]SerializeOverride
FunctionSerializeOverrides map[string]SerializeOverride
AliasQuoteChar byte
IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc
}
func NewDialect(params DialectParams) Dialect {
return &dialectImpl{
name: params.Name,
packageName: params.PackageName,
operatorSerializeOverrides: params.OperatorSerializeOverrides,
functionSerializeOverrides: params.FunctionSerializeOverrides,
aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder,
}
}
type dialectImpl struct {
name string
packageName string
operatorSerializeOverrides map[string]SerializeOverride
functionSerializeOverrides map[string]SerializeOverride
aliasQuoteChar byte
identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc
supportsReturning bool
}
func (d *dialectImpl) Name() string {
return d.name
}
func (d *dialectImpl) PackageName() string {
return d.packageName
}
func (d *dialectImpl) OperatorSerializeOverride(operator string) SerializeOverride {
if d.operatorSerializeOverrides == nil {
return nil
}
return d.operatorSerializeOverrides[operator]
}
func (d *dialectImpl) FunctionSerializeOverride(function string) SerializeOverride {
if d.functionSerializeOverrides == nil {
return nil
}
return d.functionSerializeOverrides[function]
}
func (d *dialectImpl) AliasQuoteChar() byte {
return d.aliasQuoteChar
}
func (d *dialectImpl) IdentifierQuoteChar() byte {
return d.identifierQuoteChar
}
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder
}

View file

@ -1,8 +1,9 @@
package jet package jet
type enumValue struct { type enumValue struct {
expressionInterfaceImpl ExpressionInterfaceImpl
stringInterfaceImpl stringInterfaceImpl
name string name string
} }
@ -10,13 +11,12 @@ type enumValue struct {
func NewEnumValue(name string) StringExpression { func NewEnumValue(name string) StringExpression {
enumValue := &enumValue{name: name} enumValue := &enumValue{name: name}
enumValue.expressionInterfaceImpl.parent = enumValue enumValue.ExpressionInterfaceImpl.Parent = enumValue
enumValue.stringInterfaceImpl.parent = enumValue enumValue.stringInterfaceImpl.parent = enumValue
return enumValue return enumValue
} }
func (e enumValue) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (e enumValue) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.insertConstantArgument(e.name) out.insertConstantArgument(e.name)
return nil
} }

178
internal/jet/expression.go Normal file
View file

@ -0,0 +1,178 @@
package jet
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
Serializer
Projection
GroupByClause
OrderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression
AS(alias string) Projection
// Expression will be used to sort query result in ascending order
ASC() OrderByClause
// Expression will be used to sort query result in ascending order
DESC() OrderByClause
}
type ExpressionInterfaceImpl struct {
Parent Expression
}
func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
return e.Parent
}
func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.Parent, "IS NULL")
}
func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostifxBoolExpression(e.Parent, "IS NOT NULL")
}
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN")
}
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN")
}
func (e *ExpressionInterfaceImpl) AS(alias string) Projection {
return newAlias(e.Parent, alias)
}
func (e *ExpressionInterfaceImpl) ASC() OrderByClause {
return newOrderByClause(e.Parent, true)
}
func (e *ExpressionInterfaceImpl) DESC() OrderByClause {
return newOrderByClause(e.Parent, false)
}
func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SqlBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct {
lhs, rhs Expression
additionalParam Expression
operator string
}
func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam ...Expression) binaryOpExpression {
binaryExpression := binaryOpExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
if len(additionalParam) > 0 {
binaryExpression.additionalParam = additionalParam[0]
}
return binaryExpression
}
func (c *binaryOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator")
}
if c.rhs == nil {
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, noWrap)
if wrap {
out.WriteString("(")
}
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, options...)
} else {
c.lhs.serialize(statement, out)
out.WriteString(c.operator)
c.rhs.serialize(statement, out)
}
if wrap {
out.WriteString(")")
}
}
// A prefix operator Expression
type prefixOpExpression struct {
expression Expression
operator string
}
func newPrefixExpression(expression Expression, operator string) prefixOpExpression {
prefixExpression := prefixOpExpression{
expression: expression,
operator: operator,
}
return prefixExpression
}
func (p *prefixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("(")
out.WriteString(p.operator)
if p.expression == nil {
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out)
out.WriteString(")")
}
// A postifx operator Expression
type postfixOpExpression struct {
expression Expression
operator string
}
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression {
postfixOpExpression := postfixOpExpression{
expression: expression,
operator: operator,
}
return postfixOpExpression
}
func (p *postfixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if p.expression == nil {
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out)
out.WriteString(p.operator)
}

View file

@ -4,10 +4,13 @@ import (
"testing" "testing"
) )
func TestInvalidExpression(t *testing.T) {
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
}
func TestExpressionIS_NULL(t *testing.T) { func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")
assertClauseSerializeErr(t, table2Col3.ADD(nil), "jet: nil rhs")
} }
func TestExpressionIS_NOT_NULL(t *testing.T) { func TestExpressionIS_NOT_NULL(t *testing.T) {
@ -26,33 +29,14 @@ func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
} }
func TestIN(t *testing.T) { func TestIN(t *testing.T) {
assertClauseSerialize(t, table2ColInt.IN(Int(1), Int(2), Int(3)),
`(table2.col_int IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
} }
func TestNOT_IN(t *testing.T) { func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), assertClauseSerialize(t, table2ColInt.NOT_IN(Int(1), Int(2), Int(3)),
`($1 NOT IN (( `(table2.col_int NOT IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
} }

View file

@ -81,12 +81,12 @@ func (n *floatInterfaceImpl) MOD(expression NumericExpression) FloatExpression {
} }
func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "^") return POW(n.parent, expression)
} }
//---------------------------------------------------// //---------------------------------------------------//
type binaryFloatExpression struct { type binaryFloatExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
floatInterfaceImpl floatInterfaceImpl
binaryOpExpression binaryOpExpression
@ -97,7 +97,7 @@ func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpress
floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
floatExpression.expressionInterfaceImpl.parent = &floatExpression floatExpression.ExpressionInterfaceImpl.Parent = &floatExpression
floatExpression.floatInterfaceImpl.parent = &floatExpression floatExpression.floatInterfaceImpl.parent = &floatExpression
return &floatExpression return &floatExpression

View file

@ -70,8 +70,8 @@ func TestFloatExpressionMOD(t *testing.T) {
} }
func TestFloatExpressionPOW(t *testing.T) { func TestFloatExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "(table1.col_float ^ table2.col_float)") assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "POW(table1.col_float, table2.col_float)")
assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "(table1.col_float ^ $1)", float64(2.11)) assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "POW(table1.col_float, $1)", float64(2.11))
} }
func TestFloatExp(t *testing.T) { func TestFloatExp(t *testing.T) {

View file

@ -1,7 +1,5 @@
package jet package jet
import "errors"
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression { func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil) return newFunc("ROW", expressions, nil)
@ -11,7 +9,7 @@ func ROW(expressions ...Expression) Expression {
// ABSf calculates absolute value from float expression // ABSf calculates absolute value from float expression
func ABSf(floatExpression FloatExpression) FloatExpression { func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression) return NewFloatFunc("ABS", floatExpression)
} }
// ABSi calculates absolute value from int expression // ABSi calculates absolute value from int expression
@ -19,62 +17,72 @@ func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression) return newIntegerFunc("ABS", integerExpression)
} }
// POW calculates power of base with exponent
func POW(base, exponent NumericExpression) FloatExpression {
return NewFloatFunc("POW", base, exponent)
}
// POWER calculates power of base with exponent
func POWER(base, exponent NumericExpression) FloatExpression {
return NewFloatFunc("POWER", base, exponent)
}
// SQRT calculates square root of numeric expression // SQRT calculates square root of numeric expression
func SQRT(numericExpression NumericExpression) FloatExpression { func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression) return NewFloatFunc("SQRT", numericExpression)
} }
// CBRT calculates cube root of numeric expression // CBRT calculates cube root of numeric expression
func CBRT(numericExpression NumericExpression) FloatExpression { func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression) return NewFloatFunc("CBRT", numericExpression)
} }
// CEIL calculates ceil of float expression // CEIL calculates ceil of float expression
func CEIL(floatExpression FloatExpression) FloatExpression { func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression) return NewFloatFunc("CEIL", floatExpression)
} }
// FLOOR calculates floor of float expression // FLOOR calculates floor of float expression
func FLOOR(floatExpression FloatExpression) FloatExpression { func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression) return NewFloatFunc("FLOOR", floatExpression)
} }
// ROUND calculates round of a float expressions with optional precision // ROUND calculates round of a float expressions with optional precision
func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression { func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 { if len(precision) > 0 {
return newFloatFunc("ROUND", floatExpression, precision[0]) return NewFloatFunc("ROUND", floatExpression, precision[0])
} }
return newFloatFunc("ROUND", floatExpression) return NewFloatFunc("ROUND", floatExpression)
} }
// SIGN returns sign of float expression // SIGN returns sign of float expression
func SIGN(floatExpression FloatExpression) FloatExpression { func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression) return NewFloatFunc("SIGN", floatExpression)
} }
// TRUNC calculates trunc of float expression with optional precision // TRUNC calculates trunc of float expression with optional precision
func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression { func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 { if len(precision) > 0 {
return newFloatFunc("TRUNC", floatExpression, precision[0]) return NewFloatFunc("TRUNC", floatExpression, precision[0])
} }
return newFloatFunc("TRUNC", floatExpression) return NewFloatFunc("TRUNC", floatExpression)
} }
// LN calculates natural algorithm of float expression // LN calculates natural algorithm of float expression
func LN(floatExpression FloatExpression) FloatExpression { func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression) return NewFloatFunc("LN", floatExpression)
} }
// LOG calculates logarithm of float expression // LOG calculates logarithm of float expression
func LOG(floatExpression FloatExpression) FloatExpression { func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression) return NewFloatFunc("LOG", floatExpression)
} }
// ----------------- Aggregate functions -------------------// // ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression // AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression { func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression) return NewFloatFunc("AVG", numericExpression)
} }
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. // BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
@ -109,7 +117,7 @@ func EVERY(boolExpression BoolExpression) BoolExpression {
// MAXf is aggregate function. Returns maximum value of float expression across all input values // MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression { func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression) return NewFloatFunc("MAX", floatExpression)
} }
// MAXi is aggregate function. Returns maximum value of int expression across all input values // MAXi is aggregate function. Returns maximum value of int expression across all input values
@ -119,7 +127,7 @@ func MAXi(integerExpression IntegerExpression) IntegerExpression {
// MINf is aggregate function. Returns minimum value of float expression across all input values // MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression { func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression) return NewFloatFunc("MIN", floatExpression)
} }
// MINi is aggregate function. Returns minimum value of int expression across all input values // MINi is aggregate function. Returns minimum value of int expression across all input values
@ -129,7 +137,7 @@ func MINi(integerExpression IntegerExpression) IntegerExpression {
// SUMf is aggregate function. Returns sum of expression across all float expressions // SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression { func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression) return NewFloatFunc("SUM", floatExpression)
} }
// SUMi is aggregate function. Returns sum of expression across all integer expression. // SUMi is aggregate function. Returns sum of expression across all integer expression.
@ -196,14 +204,15 @@ func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression) return newStringFunc("CHR", integerExpression)
} }
// // CONCAT adds two or more expressions together
//func CONCAT(expressions ...Expression) StringExpression { func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...) return newStringFunc("CONCAT", expressions...)
//} }
//
//func CONCAT_WS(expressions ...Expression) StringExpression { // CONCAT_WS adds two or more expressions together with a separator.
// return newStringFunc("CONCAT_WS", expressions...) func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression {
//} return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...)
}
// CONVERT converts string to dest_encoding. The original encoding is // CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding. // specified by src_encoding. The string must be valid in this encoding.
@ -235,11 +244,12 @@ func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format) return newStringFunc("DECODE", data, format)
} }
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression { // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
// args := []expressions{formatStr} func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
// args = append(args, formatArgs...) args := []Expression{formatStr}
// return newStringFunc("FORMAT", args...) args = append(args, formatArgs...)
//} return newStringFunc("FORMAT", args...)
}
// INITCAP converts the first letter of each word to upper case // INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric // and the rest to lower case. Words are sequences of alphanumeric
@ -336,6 +346,15 @@ func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number) return newStringFunc("TO_HEX", number)
} }
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression {
if len(matchType) > 0 {
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0]))
}
return newBoolFunc("REGEXP_LIKE", stringExp, pattern)
}
//----------Data Type Formatting Functions ----------------------// //----------Data Type Formatting Functions ----------------------//
// TO_CHAR converts expression to string with format // TO_CHAR converts expression to string with format
@ -350,7 +369,7 @@ func TO_DATE(dateStr, format StringExpression) DateExpression {
// TO_NUMBER converts string to numeric using format // TO_NUMBER converts string to numeric using format
func TO_NUMBER(floatStr, format StringExpression) FloatExpression { func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format) return NewFloatFunc("TO_NUMBER", floatStr, format)
} }
// TO_TIMESTAMP converts string to time stamp with time zone using format // TO_TIMESTAMP converts string to time stamp with time zone using format
@ -372,7 +391,7 @@ func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc var timezFunc *timezFunc
if len(precision) > 0 { if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0])) timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0]))
} else { } else {
timezFunc = newTimezFunc("CURRENT_TIME") timezFunc = newTimezFunc("CURRENT_TIME")
} }
@ -387,7 +406,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc var timestampzFunc *timestampzFunc
if len(precision) > 0 { if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0])) timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0]))
} else { } else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
} }
@ -402,7 +421,7 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc var timeFunc *timeFunc
if len(precision) > 0 { if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0])) timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0]))
} else { } else {
timeFunc = newTimeFunc("LOCALTIME") timeFunc = newTimeFunc("LOCALTIME")
} }
@ -417,9 +436,9 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc var timestampFunc *timestampFunc
if len(precision) > 0 { if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0])) timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0]))
} else { } else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP") timestampFunc = NewTimestampFunc("LOCALTIMESTAMP")
} }
timestampFunc.noBrackets = true timestampFunc.noBrackets = true
@ -463,7 +482,7 @@ func LEAST(value Expression, values ...Expression) Expression {
//--------------------------------------------------------------------// //--------------------------------------------------------------------//
type funcExpressionImpl struct { type funcExpressionImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
name string name string
expressions []Expression expressions []Expression
@ -477,37 +496,34 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
} }
if parent != nil { if parent != nil {
funcExp.expressionInterfaceImpl.parent = parent funcExp.ExpressionInterfaceImpl.Parent = parent
} else { } else {
funcExp.expressionInterfaceImpl.parent = funcExp funcExp.ExpressionInterfaceImpl.Parent = funcExp
} }
return funcExp return funcExp
} }
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (f *funcExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if f == nil { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
return errors.New("jet: Function expressions is nil. ") serializeOverrideFunc := serializeOverride(f.expressions...)
serializeOverrideFunc(statement, out, options...)
return
} }
addBrackets := !f.noBrackets || len(f.expressions) > 0 addBrackets := !f.noBrackets || len(f.expressions) > 0
if addBrackets { if addBrackets {
out.writeString(f.name + "(") out.WriteString(f.name + "(")
} else { } else {
out.writeString(f.name) out.WriteString(f.name)
} }
err := serializeExpressionList(statement, f.expressions, ", ", out) serializeExpressionList(statement, f.expressions, ", ", out)
if err != nil {
return err
}
if addBrackets { if addBrackets {
out.writeString(")") out.WriteString(")")
} }
return nil
} }
type boolFunc struct { type boolFunc struct {
@ -529,7 +545,7 @@ type floatFunc struct {
floatInterfaceImpl floatInterfaceImpl
} }
func newFloatFunc(name string, expressions ...Expression) FloatExpression { func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
floatFunc := &floatFunc{} floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
@ -613,7 +629,7 @@ type timestampFunc struct {
timestampInterfaceImpl timestampInterfaceImpl
} }
func newTimestampFunc(name string, expressions ...Expression) *timestampFunc { func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc {
timestampFunc := &timestampFunc{} timestampFunc := &timestampFunc{}
timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc)

View file

@ -0,0 +1,5 @@
package jet
type GroupByClause interface {
serializeForGroupBy(statement StatementType, out *SqlBuilder)
}

View file

@ -106,7 +106,7 @@ func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpressi
} }
func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "^") return IntExp(POW(i.parent, expression))
} }
func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression {
@ -131,7 +131,7 @@ func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression)
//---------------------------------------------------// //---------------------------------------------------//
type binaryIntegerExpression struct { type binaryIntegerExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
integerInterfaceImpl integerInterfaceImpl
binaryOpExpression binaryOpExpression
@ -140,7 +140,7 @@ type binaryIntegerExpression struct {
func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression {
integerExpression := binaryIntegerExpression{} integerExpression := binaryIntegerExpression{}
integerExpression.expressionInterfaceImpl.parent = &integerExpression integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression
integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
@ -150,7 +150,7 @@ func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) Int
//---------------------------------------------------// //---------------------------------------------------//
type prefixIntegerOpExpression struct { type prefixIntegerOpExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
integerInterfaceImpl integerInterfaceImpl
prefixOpExpression prefixOpExpression
@ -160,12 +160,30 @@ func newPrefixIntegerOperator(expression IntegerExpression, operator string) Int
integerExpression := prefixIntegerOpExpression{} integerExpression := prefixIntegerOpExpression{}
integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) integerExpression.prefixOpExpression = newPrefixExpression(expression, operator)
integerExpression.expressionInterfaceImpl.parent = &integerExpression integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression
return &integerExpression return &integerExpression
} }
//---------------------------------------------------//
type prefixFloatOpExpression struct {
ExpressionInterfaceImpl
floatInterfaceImpl
prefixOpExpression
}
func newPrefixFloatOperator(expression FloatExpression, operator string) FloatExpression {
floatOpExpression := prefixFloatOpExpression{}
floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator)
floatOpExpression.ExpressionInterfaceImpl.Parent = &floatOpExpression
floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression
return &floatOpExpression
}
//---------------------------------------------------// //---------------------------------------------------//
type integerExpressionWrapper struct { type integerExpressionWrapper struct {
integerInterfaceImpl integerInterfaceImpl

View file

@ -60,13 +60,28 @@ func TestIntExpressionMOD(t *testing.T) {
} }
func TestIntExpressionPOW(t *testing.T) { func TestIntExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "(table1.col_int ^ table2.col_int)") assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)")
assertClauseSerialize(t, table1ColInt.POW(Int(11)), "(table1.col_int ^ $1)", int64(11)) assertClauseSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, $1)", int64(11))
} }
func TestIntExpressionBIT_NOT(t *testing.T) { func TestIntExpressionBIT_NOT(t *testing.T) {
assertClauseSerialize(t, BIT_NOT(table2ColInt), "~ table2.col_int") assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)")
assertClauseSerialize(t, BIT_NOT(Int(11)), "~ $1", int64(11)) assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)")
}
func TestIntExpressionBIT_AND(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_AND(table2ColInt), "(table1.col_int & table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_AND(Int(11)), "(table1.col_int & $1)", int64(11))
}
func TestIntExpressionBIT_OR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_OR(table2ColInt), "(table1.col_int | table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_OR(Int(11)), "(table1.col_int | $1)", int64(11))
}
func TestIntExpressionBIT_XOR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int # table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int # $1)", int64(11))
} }
func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) { func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) {

View file

@ -14,8 +14,6 @@ var (
type keywordClause string type keywordClause string
func (k keywordClause) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (k keywordClause) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.writeString(string(k)) out.WriteString(string(k))
return nil
} }

View file

@ -0,0 +1,342 @@
package jet
import (
"fmt"
"time"
)
// LiteralExpression is representation of an escaped literal
type LiteralExpression interface {
Expression
Value() interface{}
SetConstant(constant bool)
}
type literalExpressionImpl struct {
ExpressionInterfaceImpl
value interface{}
constant bool
}
func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl {
exp := literalExpressionImpl{value: value}
if len(optionalConstant) > 0 {
exp.constant = optionalConstant[0]
}
exp.ExpressionInterfaceImpl.Parent = &exp
return &exp
}
func ConstLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value)
exp.constant = true
return exp
}
func (l *literalExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if l.constant {
out.insertConstantArgument(l.value)
} else {
out.insertParametrizedArgument(l.value)
}
}
func (l *literalExpressionImpl) Value() interface{} {
return l.value
}
func (l *literalExpressionImpl) SetConstant(constant bool) {
l.constant = constant
}
type integerLiteralExpression struct {
literalExpressionImpl
integerInterfaceImpl
}
// Int is constructor for integer expressions literals.
func Int(value int64) IntegerExpression {
numLiteral := &integerLiteralExpression{}
numLiteral.literalExpressionImpl = *literal(value)
numLiteral.literalExpressionImpl.Parent = numLiteral
numLiteral.integerInterfaceImpl.parent = numLiteral
return numLiteral
}
//---------------------------------------------------//
type boolLiteralExpression struct {
boolInterfaceImpl
literalExpressionImpl
}
// Bool creates new bool literal expression
func Bool(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{}
boolLiteralExpression.literalExpressionImpl = *literal(value)
boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression
return &boolLiteralExpression
}
//---------------------------------------------------//
type floatLiteral struct {
floatInterfaceImpl
literalExpressionImpl
}
// Float creates new float literal expression
func Float(value float64) FloatExpression {
floatLiteral := floatLiteral{}
floatLiteral.literalExpressionImpl = *literal(value)
floatLiteral.floatInterfaceImpl.parent = &floatLiteral
return &floatLiteral
}
//---------------------------------------------------//
type stringLiteral struct {
stringInterfaceImpl
literalExpressionImpl
}
// String creates new string literal expression
func String(value string) StringExpression {
stringLiteral := stringLiteral{}
stringLiteral.literalExpressionImpl = *literal(value)
stringLiteral.stringInterfaceImpl.parent = &stringLiteral
return &stringLiteral
}
//---------------------------------------------------//
type timeLiteral struct {
timeInterfaceImpl
literalExpressionImpl
}
// Time creates new time literal expression
func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression {
timeLiteral := &timeLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...)
timeLiteral.literalExpressionImpl = *literal(timeStr)
timeLiteral.timeInterfaceImpl.parent = timeLiteral
return timeLiteral
}
func TimeT(t time.Time) TimeExpression {
timeLiteral := &timeLiteral{}
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timeInterfaceImpl.parent = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timezLiteral struct {
timezInterfaceImpl
literalExpressionImpl
}
// Timez creates new time with time zone literal expression
func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression {
timezLiteral := timezLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone
timezLiteral.literalExpressionImpl = *literal(timeStr)
return TimezExp(literal(timeStr))
}
func TimezT(t time.Time) TimezExpression {
timeLiteral := &timezLiteral{}
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timezInterfaceImpl.parent = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timestampLiteral struct {
timestampInterfaceImpl
literalExpressionImpl
}
// Timestamp creates new timestamp literal expression
func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression {
timestamp := &timestampLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...)
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampInterfaceImpl.parent = timestamp
return timestamp
}
func TimestampT(t time.Time) TimestampExpression {
timestamp := &timestampLiteral{}
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampInterfaceImpl.parent = timestamp
return timestamp
}
//---------------------------------------------------//
type timestampzLiteral struct {
timestampzInterfaceImpl
literalExpressionImpl
}
// Timestamp creates new timestamp literal expression
func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression {
timestamp := &timestampzLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampzInterfaceImpl.parent = timestamp
return timestamp
}
func TimestampzT(t time.Time) TimestampzExpression {
timestamp := &timestampzLiteral{}
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampzInterfaceImpl.parent = timestamp
return timestamp
}
//---------------------------------------------------//
type dateLiteral struct {
dateInterfaceImpl
literalExpressionImpl
}
//Date creates new date expression
func Date(year int, month time.Month, day int) DateExpression {
dateLiteral := &dateLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day)
dateLiteral.literalExpressionImpl = *literal(timeStr)
dateLiteral.dateInterfaceImpl.parent = dateLiteral
return dateLiteral
}
func DateT(t time.Time) DateExpression {
dateLiteral := &dateLiteral{}
dateLiteral.literalExpressionImpl = *literal(t)
dateLiteral.dateInterfaceImpl.parent = dateLiteral
return dateLiteral
}
func formatNanoseconds(nanoseconds ...time.Duration) string {
if len(nanoseconds) > 0 && nanoseconds[0] != 0 {
duration := fmt.Sprintf("%09d", nanoseconds[0])
i := len(duration) - 1
for ; i >= 3; i-- {
if duration[i] != '0' {
break
}
}
return "." + duration[0:i+1]
}
return ""
}
//--------------------------------------------------//
type nullLiteral struct {
ExpressionInterfaceImpl
}
func newNullLiteral() Expression {
nullExpression := &nullLiteral{}
nullExpression.ExpressionInterfaceImpl.Parent = nullExpression
return nullExpression
}
func (n *nullLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("NULL")
}
//--------------------------------------------------//
type starLiteral struct {
ExpressionInterfaceImpl
}
func newStarLiteral() Expression {
starExpression := &starLiteral{}
starExpression.ExpressionInterfaceImpl.Parent = starExpression
return starExpression
}
func (n *starLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("*")
}
//---------------------------------------------------//
type wrap struct {
ExpressionInterfaceImpl
expressions []Expression
}
func (n *wrap) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("(")
serializeExpressionList(statement, n.expressions, ", ", out)
out.WriteString(")")
}
// WRAP wraps list of expressions with brackets '(' and ')'
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.ExpressionInterfaceImpl.Parent = wrap
return wrap
}
//---------------------------------------------------//
type rawExpression struct {
ExpressionInterfaceImpl
raw string
}
func (n *rawExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString(n.raw)
}
// Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()")
func Raw(raw string) Expression {
rawExp := &rawExpression{raw: raw}
rawExp.ExpressionInterfaceImpl.Parent = rawExp
return rawExp
}

View file

@ -0,0 +1,60 @@
package jet
import (
"testing"
"time"
)
func TestRawExpression(t *testing.T) {
assertClauseSerialize(t, Raw("current_database()"), "current_database()")
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
assertClauseSerialize(t, DateT(timeT), "$1", timeT)
}
func TestTimeLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Time(11, 5, 30), "'11:05:30'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 0), "'11:05:30'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 3*time.Millisecond), "'11:05:30.003'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 30*time.Millisecond), "'11:05:30.030'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 300*time.Millisecond), "'11:05:30.300'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 300*time.Microsecond), "'11:05:30.0003'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 4*time.Nanosecond), "'11:05:30.000000004'")
}
func TestTimeT(t *testing.T) {
timeT := time.Date(2000, 1, 1, 11, 40, 20, 124, time.UTC)
assertClauseDebugSerialize(t, TimeT(timeT), `'2000-01-01 11:40:20.000000124Z'`)
}
func TestTimezLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timez(11, 5, 30, 10*time.Nanosecond, "UTC"), "'11:05:30.00000001 UTC'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 0, "+1"), "'11:05:30 +1'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 3*time.Microsecond, "-7"), "'11:05:30.000003 -7'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 30*time.Millisecond, "+8:00"), "'11:05:30.030 +8:00'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 300*time.Nanosecond, "America/New_Yor"), "'11:05:30.0000003 America/New_Yor'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 3000*time.Nanosecond, "zulu"), "'11:05:30.000003 zulu'")
}
func TestTimestampLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timestamp(2011, 1, 8, 11, 5, 30), "'2011-01-08 11:05:30'")
assertClauseDebugSerialize(t, Timestamp(2011, 2, 7, 11, 5, 30, 0), "'2011-02-07 11:05:30'")
assertClauseDebugSerialize(t, Timestamp(2011, 3, 6, 11, 5, 30, 3*time.Millisecond), "'2011-03-06 11:05:30.003'")
assertClauseDebugSerialize(t, Timestamp(2011, 4, 5, 11, 5, 30, 30*time.Millisecond), "'2011-04-05 11:05:30.030'")
assertClauseDebugSerialize(t, Timestamp(2011, 5, 4, 11, 5, 30, 300*time.Millisecond), "'2011-05-04 11:05:30.300'")
assertClauseDebugSerialize(t, Timestamp(2011, 6, 3, 11, 5, 30, 3000*time.Microsecond), "'2011-06-03 11:05:30.003'")
}
func TestTimestampzLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timestampz(2011, 1, 8, 11, 5, 30, 0, "UTC"), "'2011-01-08 11:05:30 UTC'")
assertClauseDebugSerialize(t, Timestampz(2011, 2, 7, 11, 5, 30, 0, "PST"), "'2011-02-07 11:05:30 PST'")
assertClauseDebugSerialize(t, Timestampz(2011, 3, 6, 11, 5, 30, 3, "+4:00"), "'2011-03-06 11:05:30.000000003 +4:00'")
assertClauseDebugSerialize(t, Timestampz(2011, 4, 5, 11, 5, 30, 30, "-8:00"), "'2011-04-05 11:05:30.00000003 -8:00'")
assertClauseDebugSerialize(t, Timestampz(2011, 5, 4, 11, 5, 30, 300, "400"), "'2011-05-04 11:05:30.0000003 400'")
assertClauseDebugSerialize(t, Timestampz(2011, 6, 3, 11, 5, 30, 3000, "zulu"), "'2011-06-03 11:05:30.000003 zulu'")
}
func TestDate(t *testing.T) {
assertClauseDebugSerialize(t, Date(2019, 8, 8), `'2019-08-08'`)
}

View file

@ -1,6 +1,10 @@
package jet package jet
import "errors" const (
StringConcatOperator = "||"
StringRegexpLikeOperator = "REGEXP"
StringNotRegexpLikeOperator = "NOT REGEXP"
)
//----------- Logical operators ---------------// //----------- Logical operators ---------------//
@ -11,13 +15,16 @@ func NOT(exp BoolExpression) BoolExpression {
// BIT_NOT inverts every bit in integer expression result // BIT_NOT inverts every bit in integer expression result
func BIT_NOT(expr IntegerExpression) IntegerExpression { func BIT_NOT(expr IntegerExpression) IntegerExpression {
if literalExp, ok := expr.(LiteralExpression); ok {
literalExp.SetConstant(true)
}
return newPrefixIntegerOperator(expr, "~") return newPrefixIntegerOperator(expr, "~")
} }
//----------- Comparison operators ---------------// //----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery // EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression { func EXISTS(subQuery Expression) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS") return newPrefixBoolOperator(subQuery, "EXISTS")
} }
@ -71,7 +78,7 @@ type CaseOperator interface {
} }
type caseOperatorImpl struct { type caseOperatorImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
expression Expression expression Expression
when []Expression when []Expression
@ -87,7 +94,7 @@ func CASE(expression ...Expression) CaseOperator {
caseExp.expression = expression[0] caseExp.expression = expression[0]
} }
caseExp.expressionInterfaceImpl.parent = caseExp caseExp.ExpressionInterfaceImpl.Parent = caseExp
return caseExp return caseExp
} }
@ -108,55 +115,33 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
return c return c
} }
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (c *caseOperatorImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if c == nil { out.WriteString("(CASE")
return errors.New("jet: Case Expression is nil. ")
}
out.writeString("(CASE")
if c.expression != nil { if c.expression != nil {
err := c.expression.serialize(statement, out) c.expression.serialize(statement, out)
if err != nil {
return err
}
} }
if len(c.when) == 0 || len(c.then) == 0 { if len(c.when) == 0 || len(c.then) == 0 {
return errors.New("jet: Invalid case Statement. There should be at least one when/then Expression pair. ") panic("jet: invalid case Statement. There should be at least one WHEN/THEN pair. ")
} }
if len(c.when) != len(c.then) { if len(c.when) != len(c.then) {
return errors.New("jet: When and then Expression count mismatch. ") panic("jet: WHEN and THEN expression count mismatch. ")
} }
for i, when := range c.when { for i, when := range c.when {
out.writeString("WHEN") out.WriteString("WHEN")
err := when.serialize(statement, out, noWrap) when.serialize(statement, out, noWrap)
if err != nil { out.WriteString("THEN")
return err c.then[i].serialize(statement, out, noWrap)
}
out.writeString("THEN")
err = c.then[i].serialize(statement, out, noWrap)
if err != nil {
return err
}
} }
if c.els != nil { if c.els != nil {
out.writeString("ELSE") out.WriteString("ELSE")
err := c.els.serialize(statement, out, noWrap) c.els.serialize(statement, out, noWrap)
if err != nil {
return err
}
} }
out.writeString("END)") out.WriteString("END)")
return nil
} }

View file

@ -5,10 +5,10 @@ import "testing"
func TestOperatorNOT(t *testing.T) { func TestOperatorNOT(t *testing.T) {
notExpression := NOT(Int(2).EQ(Int(1))) notExpression := NOT(Int(2).EQ(Int(1)))
assertClauseSerialize(t, NOT(table1ColBool), "NOT table1.col_bool") assertClauseSerialize(t, NOT(table1ColBool), "(NOT table1.col_bool)")
assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1)) assertClauseSerialize(t, notExpression, "(NOT ($1 = $2))", int64(2), int64(1))
assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1)) assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `(NOT ($1 = $2)) AS "alias_not_expression"`, int64(2), int64(1))
assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5)) assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `((NOT ($1 = $2)) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5))
} }
func TestCase1(t *testing.T) { func TestCase1(t *testing.T) {

View file

@ -0,0 +1,29 @@
package jet
// OrderByClause
type OrderByClause interface {
serializeForOrderBy(statement StatementType, out *SqlBuilder)
}
type orderByClauseImpl struct {
expression Expression
ascent bool
}
func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) {
if o.expression == nil {
panic("jet: nil expression in ORDER BY clause")
}
o.expression.serializeForOrderBy(statement, out)
if o.ascent {
out.WriteString("ASC")
} else {
out.WriteString("DESC")
}
}
func newOrderByClause(expression Expression, ascent bool) OrderByClause {
return &orderByClauseImpl{expression: expression, ascent: ascent}
}

View file

@ -0,0 +1,27 @@
package jet
type Projection interface {
serializeForProjection(statement StatementType, out *SqlBuilder)
fromImpl(subQuery SelectTable) Projection
}
func SerializeForProjection(projection Projection, statementType StatementType, out *SqlBuilder) {
projection.serializeForProjection(statementType, out)
}
// ProjectionList is a redefined type, so that ProjectionList can be used as a Projection.
type ProjectionList []Projection
func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, projection := range cl {
newProjectionList = append(newProjectionList, projection.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ProjectionList) serializeForProjection(statement StatementType, out *SqlBuilder) {
SerializeProjectionList(statement, cl, out)
}

View file

@ -0,0 +1,46 @@
package jet
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
Serializer
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func NewSelectLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
}

View file

@ -0,0 +1,42 @@
package jet
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Alias() string
AllColumns() ProjectionList
}
type SelectTableImpl struct {
selectStmt StatementWithProjections
alias string
projections ProjectionList
}
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl {
selectTable := SelectTableImpl{selectStmt: selectStmt, alias: alias}
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return selectTable
}
func (s *SelectTableImpl) Alias() string {
return s.alias
}
func (s *SelectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *SelectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if s == nil {
panic("jet: expression table is nil. ")
}
s.selectStmt.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)
}

View file

@ -0,0 +1,37 @@
package jet
type SerializeOption int
const (
noWrap SerializeOption = iota
)
type StatementType string
const (
SelectStatementType StatementType = "SELECT"
InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE"
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
)
type Serializer interface {
serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption)
}
func Serialize(exp Serializer, statementType StatementType, out *SqlBuilder, options ...SerializeOption) {
exp.serialize(statementType, out, options...)
}
func contains(options []SerializeOption, option SerializeOption) bool {
for _, opt := range options {
if opt == option {
return true
}
}
return false
}

View file

@ -21,8 +21,10 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(uint(32)), "32") assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32") assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64") assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'") assert.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString("It's text"), "'It''s text'")
assert.Equal(t, argToString([]byte("john")), "'john'") assert.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")

176
internal/jet/sql_builder.go Normal file
View file

@ -0,0 +1,176 @@
package jet
import (
"bytes"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
)
type SqlBuilder struct {
Dialect Dialect
Buff bytes.Buffer
Args []interface{}
lastChar byte
ident int
debug bool
}
const defaultIdent = 5
func (s *SqlBuilder) IncreaseIdent(ident ...int) {
if len(ident) > 0 {
s.ident += ident[0]
} else {
s.ident += defaultIdent
}
}
func (s *SqlBuilder) DecreaseIdent(ident ...int) {
toDecrease := defaultIdent
if len(ident) > 0 {
toDecrease = ident[0]
}
if s.ident < toDecrease {
s.ident = 0
}
s.ident -= toDecrease
}
func (s *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) {
s.IncreaseIdent()
SerializeProjectionList(statement, projections, s)
s.DecreaseIdent()
}
func (s *SqlBuilder) NewLine() {
s.write([]byte{'\n'})
s.write(bytes.Repeat([]byte{' '}, s.ident))
}
func (s *SqlBuilder) write(data []byte) {
if len(data) == 0 {
return
}
if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 {
s.Buff.WriteByte(' ')
}
s.Buff.Write(data)
s.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
func (s *SqlBuilder) WriteAlias(str string) {
aliasQuoteChar := string(s.Dialect.AliasQuoteChar())
s.WriteString(aliasQuoteChar + str + aliasQuoteChar)
}
func (s *SqlBuilder) WriteString(str string) {
s.write([]byte(str))
}
func (s *SqlBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap || len(alwaysQuote) > 0 {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar)
} else {
s.WriteString(name)
}
}
func (s *SqlBuilder) WriteByte(b byte) {
s.write([]byte{b})
}
func (s *SqlBuilder) finalize() (string, []interface{}) {
return s.Buff.String() + ";\n", s.Args
}
func (s *SqlBuilder) insertConstantArgument(arg interface{}) {
s.WriteString(argToString(arg))
}
func (s *SqlBuilder) insertParametrizedArgument(arg interface{}) {
if s.debug {
s.insertConstantArgument(arg)
return
}
s.Args = append(s.Args, arg)
argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args))
s.WriteString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

145
internal/jet/statement.go Normal file
View file

@ -0,0 +1,145 @@
package jet
import (
"context"
"database/sql"
"github.com/go-jet/jet/execution"
)
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement interface {
// Sql returns parametrized sql query with list of arguments.
Sql() (query string, args []interface{})
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes.
DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination.
// Destination can be arbitrary structure
Query(db execution.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be of arbitrary structure
QueryContext(context context.Context, db execution.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows.
Exec(db execution.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows.
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
}
type SerializerStatement interface {
Serializer
Statement
}
type StatementWithProjections interface {
Statement
HasProjections
Serializer
}
type HasProjections interface {
projections() ProjectionList
}
type SerializerStatementInterfaceImpl struct {
dialect Dialect
statementType StatementType
parent SerializerStatement
}
func (s *SerializerStatementInterfaceImpl) Sql() (query string, args []interface{}) {
queryData := &SqlBuilder{Dialect: s.dialect}
s.parent.serialize(s.statementType, queryData, noWrap)
query, args = queryData.finalize()
return
}
func (s *SerializerStatementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SqlBuilder{Dialect: s.dialect, debug: true}
s.parent.serialize(s.statementType, sqlBuilder, noWrap)
query, _ = sqlBuilder.finalize()
return
}
func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error {
query, args := s.Sql()
return execution.Query(context.Background(), db, query, args, destination)
}
func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
query, args := s.Sql()
return execution.Query(context, db, query, args, destination)
}
func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.Exec(query, args...)
}
func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.ExecContext(context, query, args...)
}
type ExpressionStatementImpl struct {
ExpressionInterfaceImpl
StatementImpl
}
func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
s.serialize(statement, out)
}
func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl {
return StatementImpl{
SerializerStatementInterfaceImpl: SerializerStatementInterfaceImpl{
parent: parent,
dialect: Dialect,
statementType: statementType,
},
Clauses: clauses,
}
}
type StatementImpl struct {
SerializerStatementInterfaceImpl
Clauses []Clause
}
func (s *StatementImpl) projections() ProjectionList {
for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections()
}
}
return nil
}
func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
out.WriteString("(")
out.IncreaseIdent()
}
for _, clause := range s.Clauses {
clause.Serialize(statement, out)
}
if !contains(options, noWrap) {
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")
}
}

View file

@ -18,8 +18,9 @@ type StringExpression interface {
LIKE(pattern StringExpression) BoolExpression LIKE(pattern StringExpression) BoolExpression
NOT_LIKE(pattern StringExpression) BoolExpression NOT_LIKE(pattern StringExpression) BoolExpression
SIMILAR_TO(pattern StringExpression) BoolExpression
NOT_SIMILAR_TO(pattern StringExpression) BoolExpression REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression
NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression
} }
type stringInterfaceImpl struct { type stringInterfaceImpl struct {
@ -59,7 +60,7 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression {
} }
func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression {
return newBinaryStringExpression(s.parent, rhs, "||") return newBinaryStringExpression(s.parent, rhs, StringConcatOperator)
} }
func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression {
@ -70,17 +71,18 @@ func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression
return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE") return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE")
} }
func (s *stringInterfaceImpl) SIMILAR_TO(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "SIMILAR TO") return newBinaryBoolOperator(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
} }
func (s *stringInterfaceImpl) NOT_SIMILAR_TO(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "NOT SIMILAR TO") return newBinaryBoolOperator(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
} }
//---------------------------------------------------// //---------------------------------------------------//
type binaryStringExpression struct { type binaryStringExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
stringInterfaceImpl stringInterfaceImpl
binaryOpExpression binaryOpExpression
@ -90,7 +92,7 @@ func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpre
boolExpression := binaryStringExpression{} boolExpression := binaryStringExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression boolExpression.ExpressionInterfaceImpl.Parent = &boolExpression
boolExpression.stringInterfaceImpl.parent = &boolExpression boolExpression.stringInterfaceImpl.parent = &boolExpression
return &boolExpression return &boolExpression

View file

@ -66,14 +66,14 @@ func TestStringNOT_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN") assertClauseSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN")
} }
func TestStringSIMILAR_TO(t *testing.T) { func TestStringREGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.SIMILAR_TO(table2ColStr), "(table3.col2 SIMILAR TO table2.col_str)") assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.SIMILAR_TO(String("JOHN")), "(table3.col2 SIMILAR TO $1)", "JOHN") assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP $1)", "JOHN")
} }
func TestStringNOT_SIMILAR_TO(t *testing.T) { func TestStringNOT_REGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(table2ColStr), "(table3.col2 NOT SIMILAR TO table2.col_str)") assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(String("JOHN")), "(table3.col2 NOT SIMILAR TO $1)", "JOHN") assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP $1)", "JOHN")
} }
func TestStringExp(t *testing.T) { func TestStringExp(t *testing.T) {

209
internal/jet/table.go Normal file
View file

@ -0,0 +1,209 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
)
type SerializerTable interface {
Serializer
TableInterface
}
type TableInterface interface {
columns() []Column
SchemaName() string
TableName() string
AS(alias string)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...ColumnExpression) TableImpl {
t := TableImpl{
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.setTableName(name)
}
return t
}
type TableImpl struct {
schemaName string
name string
alias string
columnList []ColumnExpression
}
func (t *TableImpl) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
c.setTableName(alias)
}
}
func (t *TableImpl) SchemaName() string {
return t.schemaName
}
func (t *TableImpl) TableName() string {
return t.name
}
func (t *TableImpl) columns() []Column {
ret := []Column{}
for _, col := range t.columnList {
ret = append(ret, col)
}
return ret
}
func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if t == nil {
panic("jet: tableImpl is nil")
}
out.WriteIdentifier(t.schemaName)
out.WriteString(".")
out.WriteIdentifier(t.name)
if len(t.alias) > 0 {
out.WriteString("AS")
out.WriteIdentifier(t.alias)
}
}
type JoinType int
const (
InnerJoin JoinType = iota
LeftJoin
RightJoin
FullJoin
CrossJoin
)
// Join expressions are pseudo readable tables.
type JoinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
joinTable := JoinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
return joinTable
}
func (t *JoinTableImpl) SchemaName() string {
if table, ok := t.lhs.(TableInterface); ok {
return table.SchemaName()
}
return ""
}
func (t *JoinTableImpl) TableName() string {
return ""
}
func (t *JoinTableImpl) Columns() []Column {
var ret []Column
if lhsTable, ok := t.lhs.(TableInterface); ok {
ret = append(ret, lhsTable.columns()...)
}
if rhsTable, ok := t.rhs.(TableInterface); ok {
ret = append(ret, rhsTable.columns()...)
}
return ret
}
func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if t == nil {
panic("jet: Join table is nil. ")
}
if utils.IsNil(t.lhs) {
panic("jet: left hand side of join operation is nil table")
}
t.lhs.serialize(statement, out)
out.NewLine()
switch t.joinType {
case InnerJoin:
out.WriteString("INNER JOIN")
case LeftJoin:
out.WriteString("LEFT JOIN")
case RightJoin:
out.WriteString("RIGHT JOIN")
case FullJoin:
out.WriteString("FULL JOIN")
case CrossJoin:
out.WriteString("CROSS JOIN")
}
if utils.IsNil(t.rhs) {
panic("jet: right hand side of join operation is nil table")
}
t.rhs.serialize(statement, out)
if t.onCondition == nil && t.joinType != CrossJoin {
panic("jet: join condition is nil")
}
if t.onCondition != nil {
out.WriteString("ON")
t.onCondition.serialize(statement, out)
}
}
func UnwindColumns(column1 Column, columns ...Column) []Column {
columnList := []Column{}
if val, ok := column1.(IColumnList); ok {
for _, col := range val.columns() {
columnList = append(columnList, col)
}
columnList = append(columnList, columns...)
} else {
columnList = append(columnList, column1)
columnList = append(columnList, columns...)
}
return columnList
}
func UnwidColumnList(columns []Column) []Column {
ret := []Column{}
for _, col := range columns {
if columnList, ok := col.(IColumnList); ok {
for _, c := range columnList.columns() {
ret = append(ret, c)
}
} else {
ret = append(ret, col)
}
}
return ret
}

85
internal/jet/testutils.go Normal file
View file

@ -0,0 +1,85 @@
package jet
import (
"gotest.tools/assert"
"strconv"
"testing"
)
var DefaultDialect = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
var table2ColInt = IntegerColumn("col_int")
var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool")
var table2ColTime = TimeColumn("col_time")
var table2ColTimez = TimezColumn("col_timez")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2")
var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol)
func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SqlBuilder{Dialect: DefaultDialect}
clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
out := SqlBuilder{Dialect: DefaultDialect}
clause.serialize(SelectStatementType, &out)
}
func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SqlBuilder{Dialect: DefaultDialect, debug: true}
clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SqlBuilder{Dialect: DefaultDialect}
projection.serializeForProjection(SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}

View file

@ -53,7 +53,7 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression {
//---------------------------------------------------// //---------------------------------------------------//
type prefixTimeExpression struct { type prefixTimeExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
timeInterfaceImpl timeInterfaceImpl
prefixOpExpression prefixOpExpression
@ -63,7 +63,7 @@ type prefixTimeExpression struct {
// timeExpr := prefixTimeExpression{} // timeExpr := prefixTimeExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
// //
// timeExpr.expressionInterfaceImpl.parent = &timeExpr // timeExpr.ExpressionInterfaceImpl.parent = &timeExpr
// timeExpr.timeInterfaceImpl.parent = &timeExpr // timeExpr.timeInterfaceImpl.parent = &timeExpr
// //
// return &timeExpr // return &timeExpr

View file

@ -2,52 +2,53 @@ package jet
import ( import (
"testing" "testing"
"time"
) )
var timeVar = Time(10, 20, 0, 0) var timeVar = Time(10, 20, 0, 0)
func TestTimeExpressionEQ(t *testing.T) { func TestTimeExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)") assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)")
assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1)", "10:20:00")
} }
func TestTimeExpressionNOT_EQ(t *testing.T) { func TestTimeExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)") assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)")
assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1)", "10:20:00")
} }
func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)") assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1)", "10:20:00")
} }
func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)") assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1)", "10:20:00")
} }
func TestTimeExpressionLT(t *testing.T) { func TestTimeExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)") assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1)", "10:20:00")
} }
func TestTimeExpressionLT_EQ(t *testing.T) { func TestTimeExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)") assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1)", "10:20:00")
} }
func TestTimeExpressionGT(t *testing.T) { func TestTimeExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)") assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1)", "10:20:00")
} }
func TestTimeExpressionGT_EQ(t *testing.T) { func TestTimeExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)") assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1::time without time zone)", "10:20:00.000") assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1)", "10:20:00")
} }
func TestTimeExp(t *testing.T) { func TestTimeExp(t *testing.T) {
assertClauseSerialize(t, TimeExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimeExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1)), assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)),
"(table1.col_float < $1::time without time zone)", string("01:01:01.001")) "(table1.col_float < $1)", string("01:01:01.001"))
} }

View file

@ -1,52 +1,55 @@
package jet package jet
import "testing" import (
"testing"
"time"
)
var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 0) var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 3*time.Millisecond)
func TestTimestampExpressionEQ(t *testing.T) { func TestTimestampExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp), assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp),
"(table1.col_timestamp = $1::timestamp without time zone)", "2000-01-31 10:20:00.000") "(table1.col_timestamp = $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionNOT_EQ(t *testing.T) { func TestTimestampExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionLT(t *testing.T) { func TestTimestampExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionLT_EQ(t *testing.T) { func TestTimestampExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionGT(t *testing.T) { func TestTimestampExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExpressionGT_EQ(t *testing.T) { func TestTimestampExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)") assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1::timestamp without time zone)", "2000-01-31 10:20:00.000") assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampExp(t *testing.T) { func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1::timestamp without time zone)", "2000-01-31 10:20:00.000") "(table1.col_float < $1)", "2000-01-31 10:20:00.003")
} }

View file

@ -51,6 +51,25 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression
return gtEq(t.parent, rhs) return gtEq(t.parent, rhs)
} }
//---------------------------------------------------//
type prefixTimestampzOperator struct {
ExpressionInterfaceImpl
timestampzInterfaceImpl
prefixOpExpression
}
func NewPrefixTimestampOperator(operator string, expression Expression) TimestampzExpression {
timeExpr := prefixTimestampzOperator{}
timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr
timeExpr.timestampzInterfaceImpl.parent = &timeExpr
return &timeExpr
}
//------------------------------------------------- //-------------------------------------------------
type timestampzExpressionWrapper struct { type timestampzExpressionWrapper struct {

View file

@ -1,52 +1,55 @@
package jet package jet
import "testing" import (
"testing"
"time"
)
var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2) var timestampz = Timestampz(2000, 1, 31, 10, 20, 5, 23*time.Microsecond, "+200")
func TestTimestampzExpressionEQ(t *testing.T) { func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz), assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") "(table1.col_timestampz = $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionNOT_EQ(t *testing.T) { func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionLT(t *testing.T) { func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionLT_EQ(t *testing.T) { func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionGT(t *testing.T) { func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExpressionGT_EQ(t *testing.T) { func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)") assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzExp(t *testing.T) { func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002") "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200")
} }

View file

@ -61,7 +61,7 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression {
//---------------------------------------------------// //---------------------------------------------------//
type prefixTimezExpression struct { type prefixTimezExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
timezInterfaceImpl timezInterfaceImpl
prefixOpExpression prefixOpExpression
@ -71,7 +71,7 @@ type prefixTimezExpression struct {
// timeExpr := prefixTimezExpression{} // timeExpr := prefixTimezExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
// //
// timeExpr.expressionInterfaceImpl.parent = &timeExpr // timeExpr.ExpressionInterfaceImpl.parent = &timeExpr
// timeExpr.timezInterfaceImpl.parent = &timeExpr // timeExpr.timezInterfaceImpl.parent = &timeExpr
// //
// return &timeExpr // return &timeExpr

View file

@ -2,50 +2,50 @@ package jet
import "testing" import "testing"
var timezVar = Timez(10, 20, 0, 0, 4) var timezVar = Timez(10, 20, 0, 0, "+4:00")
func TestTimezExpressionEQ(t *testing.T) { func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)") assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionNOT_EQ(t *testing.T) { func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)") assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) { func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)") assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)") assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionLT(t *testing.T) { func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)") assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionLT_EQ(t *testing.T) { func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)") assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionGT(t *testing.T) { func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)") assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1)", "10:20:00 +4:00")
} }
func TestTimezExpressionGT_EQ(t *testing.T) { func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)") assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04") assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1)", "10:20:00 +4:00")
} }
func TestTimezExp(t *testing.T) { func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float") assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)), assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04")) "(table1.col_float < $1)", string("01:01:01.000000001 +4:00"))
} }

163
internal/jet/utils.go Normal file
View file

@ -0,0 +1,163 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
"reflect"
)
func serializeOrderByClauseList(statement StatementType, orderByClauses []OrderByClause, out *SqlBuilder) {
for i, value := range orderByClauses {
if i > 0 {
out.WriteString(", ")
}
value.serializeForOrderBy(statement, out)
}
}
func serializeGroupByClauseList(statement StatementType, clauses []GroupByClause, out *SqlBuilder) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
}
if c == nil {
panic("jet: nil clause")
}
c.serializeForGroupBy(statement, out)
}
}
func SerializeClauseList(statement StatementType, clauses []Serializer, out *SqlBuilder) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
}
if c == nil {
panic("jet: nil clause")
}
c.serialize(statement, out)
}
}
func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SqlBuilder) {
for i, value := range expressions {
if i > 0 {
out.WriteString(separator)
}
value.serialize(statement, out)
}
}
func SerializeProjectionList(statement StatementType, projections []Projection, out *SqlBuilder) {
for i, col := range projections {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
if col == nil {
panic("jet: Projection is nil")
}
col.serializeForProjection(statement, out)
}
}
func SerializeColumnNames(columns []Column, out *SqlBuilder) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
out.WriteString(col.Name())
}
}
func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection
for _, column := range columns {
ret = append(ret, column)
}
return ret
}
func valueToClause(value interface{}) Serializer {
if clause, ok := value.(Serializer); ok {
return clause
}
return literal(value)
}
func UnwindRowFromModel(columns []Column, data interface{}) []Serializer {
structValue := reflect.Indirect(reflect.ValueOf(data))
row := []Serializer{}
utils.ValueMustBe(structValue, reflect.Struct, "jet: data has to be a struct")
for _, column := range columns {
columnName := column.Name()
structFieldName := utils.ToGoIdentifier(columnName)
structField := structValue.FieldByName(structFieldName)
if !structField.IsValid() {
panic("missing struct field for column : " + columnName)
}
var field interface{}
if structField.Kind() == reflect.Ptr && structField.IsNil() {
field = nil
} else {
field = reflect.Indirect(structField).Interface()
}
row = append(row, literal(field))
}
return row
}
func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer {
sliceValue := reflect.Indirect(reflect.ValueOf(data))
utils.ValueMustBe(sliceValue, reflect.Slice, "jet: data has to be a slice.")
rows := [][]Serializer{}
for i := 0; i < sliceValue.Len(); i++ {
structValue := sliceValue.Index(i)
rows = append(rows, UnwindRowFromModel(columns, structValue.Interface()))
}
return rows
}
func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer {
row := []Serializer{}
allValues := append([]interface{}{value}, values...)
for _, val := range allValues {
row = append(row, valueToClause(val))
}
return row
}

View file

@ -0,0 +1,148 @@
package testutils
import (
"bytes"
"encoding/json"
"fmt"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/jet"
"gotest.tools/assert"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
)
func AssertExec(t *testing.T, stmt jet.Statement, db execution.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db)
assert.NilError(t, err)
rows, err := res.RowsAffected()
assert.NilError(t, err)
if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0])
}
}
func AssertExecErr(t *testing.T, stmt jet.Statement, db execution.DB, errorStr string) {
_, err := stmt.Exec(db)
assert.Error(t, err, errorStr)
}
func getFullPath(relativePath string) string {
goPath := os.Getenv("GOPATH")
return filepath.Join(goPath, "src/github.com/go-jet/jet/tests", relativePath)
}
func PrintJson(v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
fmt.Println(string(jsonText))
}
func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
}
func SaveJsonFile(v interface{}, testRelativePath string) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644)
if err != nil {
panic(err)
}
}
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
}
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
assert.Assert(t, string(fileJSONData) == string(jsonData))
//assert.DeepEqual(t, string(fileJSONData), string(jsonData))
}
func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args := query.Sql()
assert.Equal(t, queryStr, expectedQuery)
if len(expectedArgs) == 0 {
return
}
assert.DeepEqual(t, args, expectedArgs)
}
func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) {
defer func() {
r := recover()
assert.Equal(t, r, errorStr)
}()
stmt.Sql()
}
func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
_, args := query.Sql()
if len(expectedArgs) > 0 {
assert.DeepEqual(t, args, expectedArgs)
}
debuqSql := query.DebugSql()
assert.Equal(t, debuqSql, expectedQuery)
}
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
out := jet.SqlBuilder{Dialect: dialect}
jet.Serialize(clause, jet.SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
if len(args) > 0 {
assert.DeepEqual(t, out.Args, args)
}
}
func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
out := jet.SqlBuilder{Dialect: dialect}
jet.Serialize(clause, jet.SelectStatementType, &out)
}
func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet.Projection, query string, args ...interface{}) {
out := jet.SqlBuilder{Dialect: dialect}
jet.SerializeForProjection(projection, jet.SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest interface{}, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
stmt.Query(db, dest)
}

View file

@ -0,0 +1,70 @@
package testutils
import (
"strings"
"time"
)
func Date(t string) *time.Time {
newTime, err := time.Parse("2006-01-02", t)
if err != nil {
panic(err)
}
return &newTime
}
func TimestampWithoutTimeZone(t string, precision int) *time.Time {
precisionStr := ""
if precision > 0 {
precisionStr = "." + strings.Repeat("9", precision)
}
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000")
if err != nil {
panic(err)
}
return &newTime
}
func TimeWithoutTimeZone(t string) *time.Time {
newTime, err := time.Parse("15:04:05", t)
if err != nil {
panic(err)
}
return &newTime
}
func TimeWithTimeZone(t string) *time.Time {
newTimez, err := time.Parse("15:04:05 -0700", t)
if err != nil {
panic(err)
}
return &newTimez
}
func TimestampWithTimeZone(t string, precision int) *time.Time {
precisionStr := ""
if precision > 0 {
precisionStr = "." + strings.Repeat("9", precision)
}
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
if err != nil {
panic(err)
}
return &newTime
}

View file

@ -1,7 +1,7 @@
package utils package utils
import ( import (
"bytes" "database/sql"
"github.com/go-jet/jet/internal/3rdparty/snaker" "github.com/go-jet/jet/internal/3rdparty/snaker"
"go/format" "go/format"
"os" "os"
@ -9,7 +9,6 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"text/template"
"time" "time"
) )
@ -62,28 +61,6 @@ func EnsureDirPath(dirPath string) error {
return nil return nil
} }
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": ToGoIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// CleanUpGeneratedFiles deletes everything at folder dir. // CleanUpGeneratedFiles deletes everything at folder dir.
func CleanUpGeneratedFiles(dir string) error { func CleanUpGeneratedFiles(dir string) error {
exist, err := DirExists(dir) exist, err := DirExists(dir)
@ -103,6 +80,14 @@ func CleanUpGeneratedFiles(dir string) error {
return nil return nil
} }
func DBClose(db *sql.DB) {
if db == nil {
return
}
db.Close()
}
// DirExists checks if folder at path exist. // DirExists checks if folder at path exist.
func DirExists(path string) (bool, error) { func DirExists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
@ -159,3 +144,28 @@ func FormatTimestamp(t time.Time) []byte {
func IsNil(v interface{}) bool { func IsNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
} }
func MustBe(v interface{}, kind reflect.Kind, errorStr string) {
if reflect.TypeOf(v).Kind() != kind {
panic(errorStr)
}
}
func ValueMustBe(v reflect.Value, kind reflect.Kind, errorStr string) {
if v.Kind() != kind {
panic(errorStr)
}
}
func TypeMustBe(v reflect.Type, kind reflect.Kind, errorStr string) {
if v.Kind() != kind {
panic(errorStr)
}
}
func MustBeInitializedPtr(val interface{}, errorStr string) {
if IsNil(val) {
panic(errorStr)
}
}

View file

@ -1,7 +1,7 @@
package utils package utils
import ( import (
"github.com/stretchr/testify/assert" "gotest.tools/assert"
"testing" "testing"
) )

Some files were not shown because too many files have changed in this diff Show more