Merge pull request #9 from go-jet/mysql

MySQL and MariaDB support
This commit is contained in:
go-jet 2019-08-16 13:10:38 +02:00 committed by GitHub
commit de57a52acc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
205 changed files with 11704 additions and 467715 deletions

View file

@ -3,17 +3,25 @@
# Check https://circleci.com/docs/2.0/language-go/ for more details
version: 2
jobs:
build:
build-postgres-and-mysql:
docker:
# specify the version
- image: circleci/golang:1.11
- image: circleci/postgres:10.6-alpine
- image: circleci/postgres:10.8-alpine
environment: # environment variables for primary container
POSTGRES_USER: jet
POSTGRES_PASSWORD: jet
POSTGRES_DB: jetdb
- image: circleci/mysql:8.0
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself
@ -22,12 +30,20 @@ jobs:
steps:
- checkout
# specify any bash command here prefixed with `run: `
- run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run:
name: Install dependencies
command: |
go get github.com/google/uuid
go get github.com/lib/pq
go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile
go get gotest.tools/assert
@ -48,14 +64,37 @@ jobs:
echo Failed waiting for Postgres && exit 1
- run:
name: Init Postgres database
name: Waiting for MySQL to be ready
command: |
cd tests
go run ./init/init.go
cd ..
for i in `seq 1 10`;
do
nc -z 127.0.0.1 3306 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run:
name: Install MySQL CLI;
command: |
sudo apt-get install default-mysql-client
- run:
name: Create MySQL user and databases
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
- run:
name: Init Postgres database
command: |
cd tests
go run ./init/init.go -testsuite all
cd ..
- run: mkdir -p $TEST_RESULTS
- run: go test -v . ./tests -coverpkg=github.com/go-jet/jet,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run:
name: Upload code coverage
@ -67,4 +106,75 @@ jobs:
- store_test_results: # Upload test results for display in Test Summary: https://circleci.com/docs/2.0/collect-test-data/
path: /tmp/test-results
build-mariadb:
docker:
# specify the version
- image: circleci/golang:1.11
- image: circleci/mariadb:10.3
command: [--default-authentication-plugin=mysql_native_password]
environment:
MYSQL_ROOT_PASSWORD: jet
MYSQL_DATABASE: dvds
MYSQL_USER: jet
MYSQL_PASSWORD: jet
working_directory: /go/src/github.com/go-jet/jet
environment: # environment variables for the build itself
TEST_RESULTS: /tmp/test-results # path to where test results will be saved
steps:
- checkout
- run:
name: Submodule init
command: |
git submodule init
git submodule update
cd ./tests/testdata && git fetch && git checkout master
- run:
name: Install dependencies
command: |
go get github.com/google/uuid
go get github.com/lib/pq
go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile
go get gotest.tools/assert
go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report
go install github.com/go-jet/jet/cmd/jet
- run:
name: Install MySQL CLI;
command: |
sudo apt-get install default-mysql-client
- run:
name: Init MariaDB database
command: |
mysql -h 127.0.0.1 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -u jet -pjet -e "create database test_sample"
- run:
name: Init MariaDB database
command: |
cd tests
go run ./init/init.go -testsuite MariaDB
cd ..
- run:
name: Run MariaDB tests
command: |
go test -v ./tests/mysql/ -source=MariaDB
workflows:
version: 2
build_and_test:
jobs:
- build-postgres-and-mysql
- build-mariadb

3
.gitignore vendored
View file

@ -17,4 +17,5 @@
# Test files
gen
.gentestdata
.gentestdata
.tests/testdata/

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "tests/testdata"]
path = tests/testdata
url = https://github.com/go-jet/jet-test-data

View file

@ -1,9 +1,11 @@
# Jet
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet)
[![Go Report Card](https://goreportcard.com/badge/github.com/go-jet/jet)](https://goreportcard.com/report/github.com/go-jet/jet)
[![Documentation](https://godoc.org/github.com/go-jet/jet?status.svg)](http://godoc.org/github.com/go-jet/jet)
[![codecov](https://codecov.io/gh/go-jet/jet/branch/develop/graph/badge.svg)](https://codecov.io/gh/go-jet/jet)
[![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop)
[![GitHub release](https://img.shields.io/github/release/go-jet/jet.svg)](https://github.com/go-jet/jet/releases)
Jet is a framework for writing type-safe SQL queries for PostgreSQL in Go, with ability to easily
convert database query result to desired arbitrary structure.
@ -258,7 +260,7 @@ above statement. Usually this is the most complex and tedious work, but with Jet
First we have to create desired structure to store query result set.
This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information).
Let's say this is our desired structure, created by combining auto-generated model types:
Let's say this is our desired structure:
```go
var dest []struct {
model.Actor
@ -506,7 +508,7 @@ return result in one database call. Handler execution will be only proportional
ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate).
With Jet you can even join the whole database and store the whole structured result in in one query call.
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/chinook_db_test.go#L40).
This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40).
The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s.
##### How quickly bugs are found

1
_config.yml Normal file
View file

@ -0,0 +1 @@
theme: jekyll-theme-tactile

View file

@ -1,34 +0,0 @@
package jet
type alias struct {
expression Expression
alias string
}
func newAlias(expression Expression, aliasName string) projection {
return &alias{
expression: expression,
alias: aliasName,
}
}
func (a *alias) from(subQuery SelectTable) projection {
column := newColumn(a.alias, "", nil)
column.parent = &column
column.subQuery = subQuery
return &column
}
func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder) error {
err := a.expression.serialize(statement, out)
if err != nil {
return err
}
out.writeString("AS")
out.writeQuotedString(a.alias)
return nil
}

255
clause.go
View file

@ -1,255 +0,0 @@
package jet
import (
"bytes"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
)
type serializeOption int
const (
noWrap serializeOption = iota
)
type clause interface {
serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error
}
func contains(options []serializeOption, option serializeOption) bool {
for _, opt := range options {
if opt == option {
return true
}
}
return false
}
type sqlBuilder struct {
buff bytes.Buffer
args []interface{}
lastChar byte
ident int
}
type statementType string
const (
selectStatement statementType = "SELECT"
insertStatement statementType = "INSERT"
updateStatement statementType = "UPDATE"
deleteStatement statementType = "DELETE"
setStatement statementType = "SET"
lockStatement statementType = "LOCK"
)
const defaultIdent = 5
func (q *sqlBuilder) increaseIdent() {
q.ident += defaultIdent
}
func (q *sqlBuilder) decreaseIdent() {
if q.ident < defaultIdent {
q.ident = 0
}
q.ident -= defaultIdent
}
func (q *sqlBuilder) writeProjections(statement statementType, projections []projection) error {
q.increaseIdent()
err := serializeProjectionList(statement, projections, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeFrom(statement statementType, table ReadableTable) error {
q.newLine()
q.writeString("FROM")
q.increaseIdent()
err := table.serialize(statement, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeWhere(statement statementType, where Expression) error {
q.newLine()
q.writeString("WHERE")
q.increaseIdent()
err := where.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.newLine()
q.writeString("GROUP BY")
q.increaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.newLine()
q.writeString("ORDER BY")
q.increaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeHaving(statement statementType, having Expression) error {
q.newLine()
q.writeString("HAVING")
q.increaseIdent()
err := having.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *sqlBuilder) writeReturning(statement statementType, returning []projection) error {
if len(returning) == 0 {
return nil
}
q.newLine()
q.writeString("RETURNING")
q.increaseIdent()
return q.writeProjections(statement, returning)
}
func (q *sqlBuilder) newLine() {
q.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident))
}
func (q *sqlBuilder) write(data []byte) {
if len(data) == 0 {
return
}
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.buff.Len() > 0 {
q.buff.WriteByte(' ')
}
q.buff.Write(data)
q.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
func (q *sqlBuilder) writeQuotedString(str string) {
q.writeString(`"` + str + `"`)
}
func (q *sqlBuilder) writeString(str string) {
q.write([]byte(str))
}
func (q *sqlBuilder) writeIdentifier(name string) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap {
q.writeString(`"` + name + `"`)
} else {
q.writeString(name)
}
}
func (q *sqlBuilder) writeByte(b byte) {
q.write([]byte{b})
}
func (q *sqlBuilder) finalize() (string, []interface{}) {
return q.buff.String() + ";\n", q.args
}
func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
q.writeString(argToString(arg))
}
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.writeString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

View file

@ -3,12 +3,19 @@ package main
import (
"flag"
"fmt"
"github.com/go-jet/jet/generator/postgres"
mysqlgen "github.com/go-jet/jet/generator/mysql"
postgresgen "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/postgres"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"os"
"strings"
)
var (
source string
host string
port int
user string
@ -22,14 +29,16 @@ var (
)
func init() {
flag.StringVar(&source, "source", "", "Database system name (PostgreSQL or MySQL)")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port")
flag.StringVar(&user, "user", "", "Database user")
flag.StringVar(&password, "password", "", "The users password")
flag.StringVar(&sslmode, "sslmode", "disable", "Whether or not to use SSL(optional)")
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional)")
flag.StringVar(&dbName, "dbname", "", "name of the database")
flag.StringVar(&schemaName, "schema", "public", "Database schema name.")
flag.StringVar(&dbName, "dbname", "", "Database name")
flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public") (ignored for MySQL)`)
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL(optional)(default "disable") (ignored for MySQL)`)
flag.StringVar(&destDir, "path", "", "Destination dir for files generated.")
}
@ -38,7 +47,11 @@ func main() {
flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, `
Usage of jet:
Jet generator 2.0.0
Usage:
-source string
Database system name (PostgreSQL or MySQL)
-host string
Database host path (Example: localhost)
-port int
@ -48,13 +61,13 @@ Usage of jet:
-password string
The users password
-dbname string
name of the database
Database name
-params string
Additional connection string parameters(optional)
-schema string
Database schema name. (default "public")
Database schema name. (default "public") (ignored for MySQL)
-sslmode string
Whether or not to use SSL(optional) (default "disable")
Whether or not to use SSL(optional) (default "disable") (ignored for MySQL)
-path string
Destination dir for files generated.
`)
@ -62,28 +75,54 @@ Usage of jet:
flag.Parse()
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" {
fmt.Println("\njet: required flag missing")
flag.Usage()
os.Exit(-2)
if source == "" || host == "" || port == 0 || user == "" || dbName == "" {
printErrorAndExit("\nERROR: required flag(s) missing")
}
genData := postgres.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
var err error
DBName: dbName,
SchemaName: schemaName,
switch strings.ToLower(strings.TrimSpace(source)) {
case strings.ToLower(postgres.Dialect.Name()):
genData := postgresgen.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName,
SchemaName: schemaName,
}
err = postgresgen.Generate(destDir, genData)
case strings.ToLower(mysql.Dialect.Name()):
dbConn := mysqlgen.DBConnection{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
DBName: dbName,
}
err = mysqlgen.Generate(destDir, dbConn)
default:
fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.")
os.Exit(-4)
}
err := postgres.Generate(destDir, genData)
if err != nil {
fmt.Println(err.Error())
os.Exit(-1)
os.Exit(-5)
}
}
func printErrorAndExit(error string) {
fmt.Println(error)
flag.Usage()
os.Exit(-2)
}

145
column.go
View file

@ -1,145 +0,0 @@
// Modeling of columns
package jet
type column interface {
Name() string
TableName() string
setTableName(table string)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// Column is common column interface for all types of columns.
type Column interface {
Expression
column
}
// The base type for real materialized columns.
type columnImpl struct {
expressionInterfaceImpl
name string
tableName string
subQuery SelectTable
}
func newColumn(name string, tableName string, parent Column) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
}
bc.expressionInterfaceImpl.parent = parent
return bc
}
func (c *columnImpl) Name() string {
return c.name
}
func (c *columnImpl) TableName() string {
return c.tableName
}
func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string {
if c.tableName != "" {
return c.tableName + "." + c.name
}
return c.name
}
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if statement == setStatement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.writeString(`"` + c.defaultAlias() + `"`) //always quote
return nil
}
return c.serialize(statement, out)
}
func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuilder) error {
err := c.serialize(statement, out)
if err != nil {
return err
}
out.writeString(`AS "` + c.defaultAlias() + `"`)
return nil
}
func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c.subQuery != nil {
out.writeIdentifier(c.subQuery.Alias())
out.writeByte('.')
out.writeQuotedString(c.defaultAlias())
} else {
if c.tableName != "" {
out.writeIdentifier(c.tableName)
out.writeByte('.')
}
out.writeIdentifier(c.name)
}
return nil
}
//------------------------------------------------------//
// ColumnList is redefined type to support list of columns as single projection
type ColumnList []Column
// projection interface implementation
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) from(subQuery SelectTable) projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.from(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serializeForProjection(statement statementType, out *sqlBuilder) error {
projections := columnListToProjectionList(cl)
err := serializeProjectionList(statement, projections, out)
if err != nil {
return err
}
return nil
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }

View file

@ -1,45 +0,0 @@
package jet
import "testing"
var dateVar = Date(2000, 12, 30)
func TestDateExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.EQ(table2ColDate), "(table1.col_date = table2.col_date)")
assertClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::date)", "2000-12-30")
}
func TestDateExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.NOT_EQ(table2ColDate), "(table1.col_date != table2.col_date)")
assertClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::date)", "2000-12-30")
}
func TestDateExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(table2ColDate), "(table1.col_date IS DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(table2ColDate), "(table1.col_date IS NOT DISTINCT FROM table2.col_date)")
assertClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::date)", "2000-12-30")
}
func TestDateExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT(table2ColDate), "(table1.col_date > table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::date)", "2000-12-30")
}
func TestDateExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.GT_EQ(table2ColDate), "(table1.col_date >= table2.col_date)")
assertClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::date)", "2000-12-30")
}
func TestDateExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT(table2ColDate), "(table1.col_date < table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::date)", "2000-12-30")
}
func TestDateExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColDate.LT_EQ(table2ColDate), "(table1.col_date <= table2.col_date)")
assertClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::date)", "2000-12-30")
}

View file

@ -1,102 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
// DeleteStatement is interface for SQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
RETURNING(projections ...projection) DeleteStatement
}
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
returning []projection
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStatement {
d.returning = projections
return d
}
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
if d == nil {
return errors.New("jet: delete statement is nil")
}
out.newLine()
out.writeString("DELETE FROM")
if d.table == nil {
return errors.New("jet: nil tableName")
}
if err := d.table.serialize(deleteStatement, out); err != nil {
return err
}
if d.where == nil {
return errors.New("jet: deleting without a WHERE clause")
}
if err := out.writeWhere(deleteStatement, d.where); err != nil {
return err
}
if err := out.writeReturning(deleteStatement, d.returning); err != nil {
return err
}
return nil
}
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &sqlBuilder{}
err = d.serializeImpl(queryData)
if err != nil {
return
}
query, args = queryData.finalize()
return
}
func (d *deleteStatementImpl) DebugSql() (query string, err error) {
return debugSql(d)
}
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(d, db, destination)
}
func (d *deleteStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, d, db, destination)
}
func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(d, db)
}
func (d *deleteStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, d, db)
}

View file

@ -1,25 +0,0 @@
package jet
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementErr(t, table1.DELETE(), `jet: deleting without a WHERE clause`)
assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: deleting without a WHERE clause`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = $1;
`, int64(1))
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), `
DELETE FROM db.table1
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`, int64(1))
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -8,18 +8,18 @@
package enum
import "github.com/go-jet/jet"
import "github.com/go-jet/jet/postgres"
var MpaaRating = &struct {
G jet.StringExpression
Pg jet.StringExpression
Pg13 jet.StringExpression
R jet.StringExpression
Nc17 jet.StringExpression
G postgres.StringExpression
Pg postgres.StringExpression
Pg13 postgres.StringExpression
R postgres.StringExpression
Nc17 postgres.StringExpression
}{
G: jet.NewEnumValue("G"),
Pg: jet.NewEnumValue("PG"),
Pg13: jet.NewEnumValue("PG-13"),
R: jet.NewEnumValue("R"),
Nc17: jet.NewEnumValue("NC-17"),
G: postgres.NewEnumValue("G"),
Pg: postgres.NewEnumValue("PG"),
Pg13: postgres.NewEnumValue("PG-13"),
R: postgres.NewEnumValue("R"),
Nc17: postgres.NewEnumValue("NC-17"),
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -9,22 +9,22 @@
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/postgres"
)
var Actor = newActorTable()
type ActorTable struct {
jet.Table
postgres.Table
//Columns
ActorID jet.ColumnInteger
FirstName jet.ColumnString
LastName jet.ColumnString
LastUpdate jet.ColumnTimestamp
ActorID postgres.ColumnInteger
FirstName postgres.ColumnString
LastName postgres.ColumnString
LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new ActorTable with assigned alias
@ -38,14 +38,14 @@ func (a *ActorTable) AS(alias string) *ActorTable {
func newActorTable() *ActorTable {
var (
ActorIDColumn = jet.IntegerColumn("actor_id")
FirstNameColumn = jet.StringColumn("first_name")
LastNameColumn = jet.StringColumn("last_name")
LastUpdateColumn = jet.TimestampColumn("last_update")
ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name")
LastUpdateColumn = postgres.TimestampColumn("last_update")
)
return &ActorTable{
Table: jet.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
//Columns
ActorID: ActorIDColumn,
@ -53,7 +53,7 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn},
AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
MutableColumns: postgres.ColumnList(FirstNameColumn, LastNameColumn, LastUpdateColumn),
}
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/postgres"
)
var Category = newCategoryTable()
type CategoryTable struct {
jet.Table
postgres.Table
//Columns
CategoryID jet.ColumnInteger
Name jet.ColumnString
LastUpdate jet.ColumnTimestamp
CategoryID postgres.ColumnInteger
Name postgres.ColumnString
LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new CategoryTable with assigned alias
@ -37,20 +37,20 @@ func (a *CategoryTable) AS(alias string) *CategoryTable {
func newCategoryTable() *CategoryTable {
var (
CategoryIDColumn = jet.IntegerColumn("category_id")
NameColumn = jet.StringColumn("name")
LastUpdateColumn = jet.TimestampColumn("last_update")
CategoryIDColumn = postgres.IntegerColumn("category_id")
NameColumn = postgres.StringColumn("name")
LastUpdateColumn = postgres.TimestampColumn("last_update")
)
return &CategoryTable{
Table: jet.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn),
Table: postgres.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn),
//Columns
CategoryID: CategoryIDColumn,
Name: NameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{NameColumn, LastUpdateColumn},
AllColumns: postgres.ColumnList(CategoryIDColumn, NameColumn, LastUpdateColumn),
MutableColumns: postgres.ColumnList(NameColumn, LastUpdateColumn),
}
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -9,31 +9,31 @@
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/postgres"
)
var Film = newFilmTable()
type FilmTable struct {
jet.Table
postgres.Table
//Columns
FilmID jet.ColumnInteger
Title jet.ColumnString
Description jet.ColumnString
ReleaseYear jet.ColumnInteger
LanguageID jet.ColumnInteger
RentalDuration jet.ColumnInteger
RentalRate jet.ColumnFloat
Length jet.ColumnInteger
ReplacementCost jet.ColumnFloat
Rating jet.ColumnString
LastUpdate jet.ColumnTimestamp
SpecialFeatures jet.ColumnString
Fulltext jet.ColumnString
FilmID postgres.ColumnInteger
Title postgres.ColumnString
Description postgres.ColumnString
ReleaseYear postgres.ColumnInteger
LanguageID postgres.ColumnInteger
RentalDuration postgres.ColumnInteger
RentalRate postgres.ColumnFloat
Length postgres.ColumnInteger
ReplacementCost postgres.ColumnFloat
Rating postgres.ColumnString
LastUpdate postgres.ColumnTimestamp
SpecialFeatures postgres.ColumnString
Fulltext postgres.ColumnString
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new FilmTable with assigned alias
@ -47,23 +47,23 @@ func (a *FilmTable) AS(alias string) *FilmTable {
func newFilmTable() *FilmTable {
var (
FilmIDColumn = jet.IntegerColumn("film_id")
TitleColumn = jet.StringColumn("title")
DescriptionColumn = jet.StringColumn("description")
ReleaseYearColumn = jet.IntegerColumn("release_year")
LanguageIDColumn = jet.IntegerColumn("language_id")
RentalDurationColumn = jet.IntegerColumn("rental_duration")
RentalRateColumn = jet.FloatColumn("rental_rate")
LengthColumn = jet.IntegerColumn("length")
ReplacementCostColumn = jet.FloatColumn("replacement_cost")
RatingColumn = jet.StringColumn("rating")
LastUpdateColumn = jet.TimestampColumn("last_update")
SpecialFeaturesColumn = jet.StringColumn("special_features")
FulltextColumn = jet.StringColumn("fulltext")
FilmIDColumn = postgres.IntegerColumn("film_id")
TitleColumn = postgres.StringColumn("title")
DescriptionColumn = postgres.StringColumn("description")
ReleaseYearColumn = postgres.IntegerColumn("release_year")
LanguageIDColumn = postgres.IntegerColumn("language_id")
RentalDurationColumn = postgres.IntegerColumn("rental_duration")
RentalRateColumn = postgres.FloatColumn("rental_rate")
LengthColumn = postgres.IntegerColumn("length")
ReplacementCostColumn = postgres.FloatColumn("replacement_cost")
RatingColumn = postgres.StringColumn("rating")
LastUpdateColumn = postgres.TimestampColumn("last_update")
SpecialFeaturesColumn = postgres.StringColumn("special_features")
FulltextColumn = postgres.StringColumn("fulltext")
)
return &FilmTable{
Table: jet.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
Table: postgres.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
//Columns
FilmID: FilmIDColumn,
@ -80,7 +80,7 @@ func newFilmTable() *FilmTable {
SpecialFeatures: SpecialFeaturesColumn,
Fulltext: FulltextColumn,
AllColumns: jet.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn},
MutableColumns: jet.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn},
AllColumns: postgres.ColumnList(FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
MutableColumns: postgres.ColumnList(TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
}
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/postgres"
)
var FilmActor = newFilmActorTable()
type FilmActorTable struct {
jet.Table
postgres.Table
//Columns
ActorID jet.ColumnInteger
FilmID jet.ColumnInteger
LastUpdate jet.ColumnTimestamp
ActorID postgres.ColumnInteger
FilmID postgres.ColumnInteger
LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new FilmActorTable with assigned alias
@ -37,20 +37,20 @@ func (a *FilmActorTable) AS(alias string) *FilmActorTable {
func newFilmActorTable() *FilmActorTable {
var (
ActorIDColumn = jet.IntegerColumn("actor_id")
FilmIDColumn = jet.IntegerColumn("film_id")
LastUpdateColumn = jet.TimestampColumn("last_update")
ActorIDColumn = postgres.IntegerColumn("actor_id")
FilmIDColumn = postgres.IntegerColumn("film_id")
LastUpdateColumn = postgres.TimestampColumn("last_update")
)
return &FilmActorTable{
Table: jet.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn),
Table: postgres.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn),
//Columns
ActorID: ActorIDColumn,
FilmID: FilmIDColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{LastUpdateColumn},
AllColumns: postgres.ColumnList(ActorIDColumn, FilmIDColumn, LastUpdateColumn),
MutableColumns: postgres.ColumnList(LastUpdateColumn),
}
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/postgres"
)
var FilmCategory = newFilmCategoryTable()
type FilmCategoryTable struct {
jet.Table
postgres.Table
//Columns
FilmID jet.ColumnInteger
CategoryID jet.ColumnInteger
LastUpdate jet.ColumnTimestamp
FilmID postgres.ColumnInteger
CategoryID postgres.ColumnInteger
LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new FilmCategoryTable with assigned alias
@ -37,20 +37,20 @@ func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable {
func newFilmCategoryTable() *FilmCategoryTable {
var (
FilmIDColumn = jet.IntegerColumn("film_id")
CategoryIDColumn = jet.IntegerColumn("category_id")
LastUpdateColumn = jet.TimestampColumn("last_update")
FilmIDColumn = postgres.IntegerColumn("film_id")
CategoryIDColumn = postgres.IntegerColumn("category_id")
LastUpdateColumn = postgres.TimestampColumn("last_update")
)
return &FilmCategoryTable{
Table: jet.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
Table: postgres.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
//Columns
FilmID: FilmIDColumn,
CategoryID: CategoryIDColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{LastUpdateColumn},
AllColumns: postgres.ColumnList(FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
MutableColumns: postgres.ColumnList(LastUpdateColumn),
}
}

View file

@ -1,6 +1,6 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Wednesday, 17-Jul-19 13:11:01 CEST
// Generated at Thursday, 08-Aug-19 16:59:58 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -9,21 +9,21 @@
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/postgres"
)
var Language = newLanguageTable()
type LanguageTable struct {
jet.Table
postgres.Table
//Columns
LanguageID jet.ColumnInteger
Name jet.ColumnString
LastUpdate jet.ColumnTimestamp
LanguageID postgres.ColumnInteger
Name postgres.ColumnString
LastUpdate postgres.ColumnTimestamp
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns postgres.IColumnList
MutableColumns postgres.IColumnList
}
// creates new LanguageTable with assigned alias
@ -37,20 +37,20 @@ func (a *LanguageTable) AS(alias string) *LanguageTable {
func newLanguageTable() *LanguageTable {
var (
LanguageIDColumn = jet.IntegerColumn("language_id")
NameColumn = jet.StringColumn("name")
LastUpdateColumn = jet.TimestampColumn("last_update")
LanguageIDColumn = postgres.IntegerColumn("language_id")
NameColumn = postgres.StringColumn("name")
LastUpdateColumn = postgres.TimestampColumn("last_update")
)
return &LanguageTable{
Table: jet.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn),
Table: postgres.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn),
//Columns
LanguageID: LanguageIDColumn,
Name: NameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: jet.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn},
MutableColumns: jet.ColumnList{NameColumn, LastUpdateColumn},
AllColumns: postgres.ColumnList(LanguageIDColumn, NameColumn, LastUpdateColumn),
MutableColumns: postgres.ColumnList(NameColumn, LastUpdateColumn),
}
}

View file

@ -9,8 +9,8 @@ import (
// dot import so go code would resemble as much as native SQL
// dot import is not mandatory
. "github.com/go-jet/jet"
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
. "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model"
)
@ -24,7 +24,6 @@ const (
)
func main() {
// Connect to database
var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)
@ -97,17 +96,15 @@ func jsonSave(path string, v interface{}) {
}
}
func printStatementInfo(stmt Statement) {
query, args, err := stmt.Sql()
panicOnError(err)
func printStatementInfo(stmt SelectStatement) {
query, args := stmt.Sql()
fmt.Println("Parameterized query: ")
fmt.Println(query)
fmt.Println("Arguments: ")
fmt.Println(args)
debugSQL, err := stmt.DebugSql()
panicOnError(err)
debugSQL := stmt.DebugSql()
fmt.Println("\n\n==============================")

View file

@ -4,10 +4,10 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/go-jet/jet/execution/internal"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"reflect"
"strconv"
"strings"
@ -18,14 +18,11 @@ import (
// Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
if utils.IsNil(destinationPtr) {
return errors.New("jet: Destination is nil")
}
utils.MustBeInitializedPtr(db, "jet: db is nil")
utils.MustBeInitializedPtr(destinationPtr, "jet: destination is nil")
utils.MustBe(destinationPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct")
destinationPtrType := reflect.TypeOf(destinationPtr)
if destinationPtrType.Kind() != reflect.Ptr {
return errors.New("jet: Destination has to be a pointer to slice or pointer to struct")
}
if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(context, db, query, args, destinationPtr)
@ -51,24 +48,11 @@ func Query(context context.Context, db DB, query string, args []interface{}, des
}
return nil
} else {
return errors.New("jet: unsupported destination type")
panic("jet: destination has to be a pointer to slice or pointer to struct")
}
}
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error {
if db == nil {
return errors.New("jet: db is nil")
}
if slicePtr == nil {
return errors.New("jet: Destination is nil. ")
}
destinationType := reflect.TypeOf(slicePtr)
if destinationType.Kind() != reflect.Ptr && destinationType.Elem().Kind() != reflect.Slice {
return errors.New("jet: Destination has to be a pointer to slice. ")
}
if ctx == nil {
ctx = context.Background()
}
@ -126,14 +110,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
sliceElemType := getSliceElemType(slicePtrValue)
if isGoBaseType(sliceElemType) {
if isSimpleModelType(sliceElemType) {
updated, err = mapRowToBaseTypeSlice(scanContext, slicePtrValue, field)
return
}
if sliceElemType.Kind() != reflect.Struct {
return false, errors.New("jet: Unsupported dest type: " + field.Name + " " + field.Type.String())
}
utils.TypeMustBe(sliceElemType, reflect.Struct, "jet: unsupported slice element type"+fieldToString(field))
structGroupKey := scanContext.getGroupKey(sliceElemType, field)
@ -226,7 +208,7 @@ func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
if implementsScannerType(field.Type) {
fieldMap.implementsScanner = true
} else if !isGoBaseType(field.Type) {
} else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true
}
@ -249,6 +231,10 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
field := structType.Field(i)
fieldValue := structValue.Field(i)
if !fieldValue.CanSet() { // private field
continue
}
fieldMap := typeInf.fieldMappings[i]
if fieldMap.complexType {
@ -284,8 +270,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
err = scanner.Scan(cellValue)
if err != nil {
err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String())
return
panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String())
}
updated = true
} else {
@ -294,12 +279,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
if cellValue != nil {
updated = true
initializeValueIfNilPtr(fieldValue)
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if err != nil {
err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String())
return
}
setReflectValue(reflect.ValueOf(cellValue), fieldValue)
}
}
}
@ -310,9 +290,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
if destPtrValue.Kind() != reflect.Ptr {
return false, errors.New("jet: Internal error. ")
}
utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
destValueKind := destPtrValue.Elem().Kind()
@ -321,7 +299,7 @@ func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrVa
} else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else {
return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String())
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
}
}
@ -331,14 +309,12 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
if dest.Kind() != reflect.Ptr {
destPtrValue = dest.Addr()
} else if dest.Kind() == reflect.Ptr {
} else {
if dest.IsNil() {
destPtrValue = reflect.New(dest.Type().Elem())
} else {
destPtrValue = dest
}
} else {
return false, errors.New("jet: Internal error. ")
}
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
@ -399,7 +375,7 @@ func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error {
if slicePtrValue.IsNil() {
panic("Slice is nil")
panic("jet: internal, slice is nil")
}
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
@ -410,8 +386,12 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
newElemValue = objPtrValue.Elem()
}
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
}
if !newElemValue.Type().AssignableTo(sliceElemType) {
return fmt.Errorf("jet: can't append %s to %s slice ", newElemValue.Type().String(), sliceValue.Type().String())
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
@ -465,6 +445,7 @@ func toCommonIdentifier(name string) string {
}
func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() {
return
}
@ -490,55 +471,119 @@ func valueToString(value reflect.Value) string {
valueInterface = value.Interface()
}
if t, ok := valueInterface.(time.Time); ok {
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
func isGoBaseType(objType reflect.Type) bool {
typeStr := objType.String()
var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New())
switch typeStr {
case "string", "int", "int16", "int32", "int64", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8",
"*string", "*int", "*int16", "*int32", "*int64", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
func isSimpleModelType(objType reflect.Type) bool {
objType = indirectType(objType)
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String,
reflect.Bool:
return true
case reflect.Slice:
return objType.Elem().Kind() == reflect.Uint8 //[]byte
case reflect.Struct:
return objType == timeType || objType == uuidType // time.Time || uuid.UUID
}
return false
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
return true
}
return false
}
func setReflectValue(source, destination reflect.Value) error {
var sourceElem reflect.Value
func tryAssign(source, destination reflect.Value) bool {
if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if isIntegerType(source.Type()) && destination.Type() == boolType {
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
}
return false
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
sourceElem = source
if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
sourceElem = source.Addr()
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
sourceElem = sourceCopy
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
}
} else {
if source.Kind() == reflect.Ptr {
sourceElem = source.Elem()
} else {
sourceElem = source
if source.IsNil() {
return
}
source = source.Elem()
}
if tryAssign(source, destination) {
return
}
}
if !sourceElem.Type().AssignableTo(destination.Type()) {
return errors.New("jet: can't set " + sourceElem.Type().String() + " to " + destination.Type().String())
}
destination.Set(sourceElem)
return nil
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
@ -555,35 +600,49 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
return values
}
var nullFloatType = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "INT2":
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4":
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8":
case "INT8", "BIGINT":
return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloatType
case "FLOAT8", "NUMERIC", "DECIMAL":
return nullFloat32Type
case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA":
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
@ -697,7 +756,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
if !isGoBaseType(field.Type) {
if !isSimpleModelType(field.Type) {
var structType reflect.Type
if field.Type.Kind() == reflect.Struct {
structType = field.Type
@ -749,7 +808,7 @@ func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer)
if !ok {
panic("Scan value doesn't implement driver.Valuer")
panic("jet: internal error, scan value doesn't implement driver.Valuer")
}
value, err := valuer.Value()
@ -791,3 +850,11 @@ func indirectType(reflectType reflect.Type) reflect.Type {
}
return reflectType.Elem()
}
func fieldToString(field *reflect.StructField) string {
if field == nil {
return ""
}
return " at '" + field.Name + " " + field.Type.String() + "'"
}

View file

@ -2,9 +2,12 @@ package internal
import (
"database/sql/driver"
"strconv"
"time"
)
//===============================================================//
// NullByteArray struct
type NullByteArray struct {
ByteArray []byte
@ -31,6 +34,8 @@ func (nb NullByteArray) Value() (driver.Value, error) {
return nb.ByteArray, nil
}
//===============================================================//
// NullTime struct
type NullTime struct {
Time time.Time
@ -38,8 +43,20 @@ type NullTime struct {
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
func (nt *NullTime) Scan(value interface{}) (err error) {
switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, nt.Valid = parseTime(string(v))
return
case string:
nt.Time, nt.Valid = parseTime(v)
return
}
nt.Valid = false
return nil
}
@ -51,24 +68,49 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil
}
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool // Valid is true if Int64 is not NULL
const formatTime = "2006-01-02 15:04:05.999999"
func parseTime(timeStr string) (t time.Time, valid bool) {
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
}
t, err := time.Parse(format, timeStr)
return t, err == nil
}
//===============================================================//
// NullInt8 struct
type NullInt8 struct {
Int8 int8
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
func (n *NullInt8) Scan(value interface{}) error {
switch v := value.(type) {
case int64:
n.Int32, n.Valid = int32(v), true
n.Int8, n.Valid = int8(v), true
return nil
case int32:
n.Int32, n.Valid = v, true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
case int8:
n.Int8, n.Valid = v, true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 8)
if err == nil {
n.Int8, n.Valid = int8(intV), true
return nil
}
}
n.Valid = false
@ -77,21 +119,24 @@ func (n *NullInt32) Scan(value interface{}) error {
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
func (n NullInt8) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
return n.Int8, nil
}
//===============================================================//
// NullInt16 struct
type NullInt16 struct {
Int16 int16
Valid bool // Valid is true if Int64 is not NULL
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error {
switch v := value.(type) {
case int64:
n.Int16, n.Valid = int16(v), true
@ -99,9 +144,18 @@ func (n *NullInt16) Scan(value interface{}) error {
case int16:
n.Int16, n.Valid = v, true
return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8:
n.Int16, n.Valid = int16(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
return nil
}
}
n.Valid = false
@ -117,10 +171,63 @@ func (n NullInt16) Value() (driver.Value, error) {
return n.Int16, nil
}
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case int64:
n.Int32, n.Valid = int32(v), true
return nil
case int32:
n.Int32, n.Valid = v, true
return nil
case int16:
n.Int32, n.Valid = int32(v), true
return nil
case uint16:
n.Int32, n.Valid = int32(v), true
return nil
case int8:
n.Int32, n.Valid = int32(v), true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32)
if err == nil {
n.Int32, n.Valid = int32(intV), true
return nil
}
}
n.Valid = false
return nil
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
//===============================================================//
// NullFloat32 struct
type NullFloat32 struct {
Float32 float32
Valid bool // Valid is true if Int64 is not NULL
Valid bool
}
// Scan implements the Scanner interface.

View file

@ -1,194 +0,0 @@
package jet
import (
"errors"
)
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
clause
projection
groupByClause
orderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression
AS(alias string) projection
// Expression will be used to sort query result in ascending order
ASC() orderByClause
// Expression will be used to sort query result in ascending order
DESC() orderByClause
}
type expressionInterfaceImpl struct {
parent Expression
}
func (e *expressionInterfaceImpl) from(subQuery SelectTable) projection {
return e.parent
}
func (e *expressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NULL")
}
func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NOT NULL")
}
func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.parent, WRAP(expressions...), "IN")
}
func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.parent, WRAP(expressions...), "NOT IN")
}
func (e *expressionInterfaceImpl) AS(alias string) projection {
return newAlias(e.parent, alias)
}
func (e *expressionInterfaceImpl) ASC() orderByClause {
return newOrderByClause(e.parent, true)
}
func (e *expressionInterfaceImpl) DESC() orderByClause {
return newOrderByClause(e.parent, false)
}
func (e *expressionInterfaceImpl) serializeForGroupBy(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForProjection(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
func (e *expressionInterfaceImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
return e.parent.serialize(statement, out, noWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct {
lhs, rhs Expression
operator string
}
func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpression {
binaryExpression := binaryOpExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
return binaryExpression
}
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c == nil {
return errors.New("jet: binary Expression is nil")
}
if c.lhs == nil {
return errors.New("jet: nil lhs")
}
if c.rhs == nil {
return errors.New("jet: nil rhs")
}
wrap := !contains(options, noWrap)
if wrap {
out.writeString("(")
}
if err := c.lhs.serialize(statement, out); err != nil {
return err
}
out.writeString(c.operator)
if err := c.rhs.serialize(statement, out); err != nil {
return err
}
if wrap {
out.writeString(")")
}
return nil
}
// A prefix operator Expression
type prefixOpExpression struct {
expression Expression
operator string
}
func newPrefixExpression(expression Expression, operator string) prefixOpExpression {
prefixExpression := prefixOpExpression{
expression: expression,
operator: operator,
}
return prefixExpression
}
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Prefix Expression is nil")
}
out.writeString(p.operator + " ")
if p.expression == nil {
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
}
return nil
}
// A postifx operator Expression
type postfixOpExpression struct {
expression Expression
operator string
}
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression {
postfixOpExpression := postfixOpExpression{
expression: expression,
operator: operator,
}
return postfixOpExpression
}
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil {
return errors.New("jet: Postifx operator Expression is nil")
}
if p.expression == nil {
return errors.New("jet: nil prefix Expression")
}
if err := p.expression.serialize(statement, out); err != nil {
return err
}
out.writeString(p.operator)
return nil
}

View file

@ -0,0 +1,174 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// ColumnMetaData struct
type ColumnMetaData struct {
Name string
IsNullable bool
DataType string
EnumName string
IsUnsigned bool
SqlBuilderColumnType string
GoBaseType string
GoModelType string
}
func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData {
columnMetaData := ColumnMetaData{
Name: name,
IsNullable: isNullable,
DataType: dataType,
EnumName: enumName,
IsUnsigned: isUnsigned,
}
columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType()
columnMetaData.GoBaseType = columnMetaData.getGoBaseType()
columnMetaData.GoModelType = columnMetaData.getGoModelType()
return columnMetaData
}
// getSqlBuilderColumnType returns type of jet sql builder column
func (c ColumnMetaData) getSqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.")
return "String"
}
}
// getGoBaseType returns model type for column info.
func (c ColumnMetaData) getGoBaseType() string {
switch c.DataType {
case "USER-DEFINED", "enum":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "tinyint":
return "int8"
case "smallint",
"year":
return "int16"
case "integer",
"mediumint", "int": //MySQL
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone",
"timestamp", "datetime", "time": // MySQL
return "time.Time"
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnMetaData) getGoModelType() string {
typeStr := c.GoBaseType
if strings.Contains(typeStr, "int") && c.IsUnsigned {
typeStr = "u" + typeStr
}
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnMetaData{}
for rows.Next() {
var name, isNullable, dataType, enumName string
var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
if err != nil {
return nil, err
}
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,140 @@
package metadata
import (
"database/sql"
"strings"
)
type DialectQuerySet interface {
ListOfTablesQuery() string
PrimaryKeysQuery() string
ListOfColumnsQuery() string
ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error)
}
type PostgresQuerySet struct{}
func (p *PostgresQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';
`
}
func (p *PostgresQuerySet) PrimaryKeysQuery() string {
return `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY';
`
}
func (p *PostgresQuerySet) ListOfColumnsQuery() string {
return `
SELECT column_name, is_nullable, data_type, udt_name, FALSE
FROM information_schema.columns
where table_schema = $1 and table_name = $2
order by ordinal_position;`
}
func (p *PostgresQuerySet) ListOfEnumsQuery() string {
return `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
}
func (p *PostgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) {
return getEnumInfos(db, p, schemaName)
}
// =======================================================================//
type MySqlQuerySet struct{}
func (m *MySqlQuerySet) ListOfTablesQuery() string {
return `
SELECT table_name
FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = 'BASE TABLE';
`
}
func (m *MySqlQuerySet) PrimaryKeysQuery() string {
return `
SELECT k.column_name
FROM information_schema.table_constraints t
JOIN information_schema.key_column_usage k
USING(constraint_name,table_schema,table_name)
WHERE t.constraint_type='PRIMARY KEY'
AND t.table_schema= ?
AND t.table_name= ?;
`
}
func (m *MySqlQuerySet) ListOfColumnsQuery() string {
return `
SELECT COLUMN_NAME,
IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE),
IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''),
COLUMN_TYPE LIKE '%unsigned%'
FROM information_schema.columns
WHERE table_schema = ? and table_name = ?
ORDER BY ordinal_position;
`
}
func (m *MySqlQuerySet) ListOfEnumsQuery() string {
return `
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5)
FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE';
`
}
func (m *MySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) {
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var enumName string
var enumValues string
err = rows.Scan(&enumName, &enumValues)
if err != nil {
return nil, err
}
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
ret = append(ret, EnumMetaData{
name: enumName,
Values: strings.Split(enumValues, ","),
})
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,32 +1,23 @@
package postgresmeta
package metadata
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// EnumInfo struct
type EnumInfo struct {
// EnumMetaData struct
type EnumMetaData struct {
name string
Values []string
}
// Name returns enum name
func (e EnumInfo) Name() string {
func (e EnumMetaData) Name() string {
return e.name
}
func getEnumInfos(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
query := `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
func getEnumInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) {
rows, err := db.Query(query, schemaName)
rows, err := db.Query(querySet.ListOfEnumsQuery(), schemaName)
if err != nil {
return nil, err
@ -55,10 +46,10 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
return nil, err
}
ret := []metadata.MetaData{}
ret := []MetaData{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, EnumInfo{
ret = append(ret, EnumMetaData{
enumName,
enumValues,
})

View file

@ -1,142 +0,0 @@
package postgresmeta
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// ColumnInfo metadata struct
type ColumnInfo struct {
Name string
IsNullable bool
DataType string
EnumName string
}
// SqlBuilderColumnType returns type of jet sql builder column
func (c ColumnInfo) SqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint":
return "Integer"
case "date":
return "Date"
case "timestamp without time zone":
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone":
return "Time"
case "time with time zone":
return "Timez"
case "USER-DEFINED", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY":
return "String"
case "real", "numeric", "decimal", "double precision":
return "Float"
default:
fmt.Println("Unsupported sql type: " + c.DataType + ", using string column instead for sql builder.")
return "String"
}
}
// GoBaseType returns model type for column info.
func (c ColumnInfo) GoBaseType() string {
switch c.DataType {
case "USER-DEFINED":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "smallint":
return "int16"
case "integer":
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone":
return "time.Time"
case "bytea":
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY":
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision":
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("Unsupported sql type: " + c.DataType + ", " + c.EnumName + ", using string instead for model type.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type, udt_name
FROM information_schema.columns
where table_catalog = $1 and table_schema = $2 and table_name = $3
order by ordinal_position;`
rows, err := db.Query(query, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnInfo{}
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
columnInfo.IsNullable = isNullable == "YES"
if err != nil {
return nil, err
}
ret = append(ret, columnInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,76 +0,0 @@
package postgresmeta
import (
"database/sql"
"github.com/go-jet/jet/generator/internal/metadata"
)
// SchemaInfo metadata struct
type SchemaInfo struct {
DatabaseName string
Name string
TableInfos []metadata.MetaData
EnumInfos []metadata.MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
schemaInfo.DatabaseName = databaseName
schemaInfo.Name = schemaName
schemaInfo.TableInfos, err = getTableInfos(db, databaseName, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = getEnumInfos(db, schemaName)
if err != nil {
return
}
return
}
func getTableInfos(db *sql.DB, dbName, schemaName string) ([]metadata.MetaData, error) {
query := `
SELECT table_name
FROM information_schema.tables
where table_catalog = $1 and table_schema = $2 and table_type = 'BASE TABLE';
`
rows, err := db.Query(query, dbName, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []metadata.MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,68 @@
package metadata
import (
"database/sql"
"fmt"
)
// SchemaMetaData struct
type SchemaMetaData struct {
TableInfos []MetaData
EnumInfos []MetaData
}
// GetSchemaInfo returns schema information from db connection.
func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) {
schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName)
if err != nil {
return
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
return
}
func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,31 +1,31 @@
package postgresmeta
package metadata
import (
"database/sql"
"github.com/go-jet/jet/internal/utils"
)
// TableInfo metadata struct
type TableInfo struct {
// TableMetaData metadata struct
type TableMetaData struct {
SchemaName string
name string
PrimaryKeys map[string]bool
Columns []ColumnInfo
Columns []ColumnMetaData
}
// Name returns table info name
func (t TableInfo) Name() string {
func (t TableMetaData) Name() string {
return t.name
}
// IsPrimaryKey returns if column is a part of primary key
func (t TableInfo) IsPrimaryKey(column string) bool {
func (t TableMetaData) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column]
}
// MutableColumns returns list of mutable columns for table
func (t TableInfo) MutableColumns() []ColumnInfo {
ret := []ColumnInfo{}
func (t TableMetaData) MutableColumns() []ColumnMetaData {
ret := []ColumnMetaData{}
for _, column := range t.Columns {
if t.IsPrimaryKey(column.Name) {
@ -39,11 +39,11 @@ func (t TableInfo) MutableColumns() []ColumnInfo {
}
// GetImports returns model imports for table.
func (t TableInfo) GetImports() []string {
func (t TableMetaData) GetImports() []string {
imports := map[string]string{}
for _, column := range t.Columns {
columnType := column.GoBaseType()
columnType := column.GoBaseType
switch columnType {
case "time.Time":
@ -63,22 +63,22 @@ func (t TableInfo) GetImports() []string {
}
// GoStructName returns go struct name for sql builder
func (t TableInfo) GoStructName() string {
func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table"
}
// GetTableInfo returns table info metadata
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) {
func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) {
tableInfo.SchemaName = schemaName
tableInfo.name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, dbName, schemaName, tableName)
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName)
if err != nil {
return
}
tableInfo.Columns, err = getColumnInfos(db, dbName, schemaName, tableName)
tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName)
if err != nil {
return
@ -87,15 +87,9 @@ func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo T
return
}
func getPrimaryKeys(db *sql.DB, dbName, schemaName, tableName string) (map[string]bool, error) {
query := `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_catalog = $1 AND t.table_schema = $2 AND t.table_name = $3 AND t.constraint_type = 'PRIMARY KEY';
`
rows, err := db.Query(query, dbName, schemaName, tableName)
func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) {
rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
if err != nil {
return nil, err

View file

@ -0,0 +1,118 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"path/filepath"
"text/template"
"time"
)
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
if len(tables) == 0 && len(enums) == 0 {
return nil
}
fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir)
if err != nil {
return err
}
fmt.Println("Generating table sql builder files...")
err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(destDir, "model", tableModelTemplate, tables, dialect)
if err != nil {
return err
}
if len(enums) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(destDir, "model", enumModelTemplate, enums, dialect)
if err != nil {
return err
}
}
fmt.Println("Done")
return nil
}
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
"dialect": func() jet.Dialect {
return dialect1
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -1,4 +1,4 @@
package postgres
package template
var autoGenWarningTemplate = `
//
@ -11,7 +11,7 @@ var autoGenWarningTemplate = `
`
var sqlBuilderTableTemplate = `
var tableSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
@ -21,21 +21,21 @@ var sqlBuilderTableTemplate = `
package table
import (
"github.com/go-jet/jet"
"github.com/go-jet/jet/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
type {{.GoStructName}} struct {
jet.Table
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} jet.Column{{.SqlBuilderColumnType}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns jet.ColumnList
MutableColumns jet.ColumnList
AllColumns {{dialect.PackageName}}.IColumnList
MutableColumns {{dialect.PackageName}}.IColumnList
}
// creates new {{.GoStructName}} with assigned alias
@ -50,26 +50,26 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
func new{{.GoStructName}}() *{{.GoStructName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = jet.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
)
return &{{.GoStructName}}{
Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: jet.ColumnList{ {{template "column-list" .Columns}} },
MutableColumns: jet.ColumnList{ {{template "column-list" .MutableColumns}} },
AllColumns: {{dialect.PackageName}}.ColumnList( {{template "column-list" .Columns}} ),
MutableColumns: {{dialect.PackageName}}.ColumnList( {{template "column-list" .MutableColumns}} ),
}
}
`
var dataModelTemplate = `package model
var tableModelTemplate = `package model
{{ if .GetImports }}
import (
@ -85,6 +85,22 @@ type {{ToGoIdentifier .Name}} struct {
{{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + `
{{- end}}
}
`
var enumSQLBuilderTemplate = `package enum
import "github.com/go-jet/jet/{{dialect.PackageName}}"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"),
{{- end}}
}
`
var enumModelTemplate = `package model
@ -121,17 +137,3 @@ func (e {{ToGoIdentifier $.Name}}) String() string {
}
`
var enumTypeTemplate = `package enum
import "github.com/go-jet/jet"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} jet.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: jet.NewEnumValue("{{$element}}"),
{{- end}}
}
`

View file

@ -0,0 +1,71 @@
package mysql
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/mysql"
"path"
)
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
if err != nil {
return err
}
defer utils.DBClose(db)
fmt.Println("Retrieving database information...")
// No schemas in MySQL
dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &metadata.MySqlQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName)
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect)
if err != nil {
return err
}
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" {
connectionString += "?" + dbConn.Params
}
db, err := sql.Open("mysql", connectionString)
fmt.Println("Connecting to MySQL database: " + connectionString)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -1,134 +0,0 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/metadata/postgresmeta"
"github.com/go-jet/jet/internal/utils"
"path"
"path/filepath"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
SchemaName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return err
}
defer db.Close()
err = db.Ping()
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := postgresmeta.GetSchemaInfo(db, dbConn.DBName, dbConn.SchemaName)
if err != nil {
return err
}
fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)")
if len(schemaInfo.TableInfos) == 0 && len(schemaInfo.EnumInfos) == 0 {
return nil
}
schemaGenPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
fmt.Println("Destination directory:", schemaGenPath)
fmt.Println("Cleaning up destination directory...")
err = utils.CleanUpGeneratedFiles(schemaGenPath)
if err != nil {
return err
}
fmt.Println("Generating table sql builder files...")
err = generate(schemaInfo, destDir, "table", sqlBuilderTableTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
fmt.Println("Generating table model files...")
err = generate(schemaInfo, destDir, "model", dataModelTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
if len(schemaInfo.EnumInfos) > 0 {
fmt.Println("Generating enum sql builder files...")
err = generate(schemaInfo, destDir, "enum", enumTypeTemplate, schemaInfo.EnumInfos)
if err != nil {
return err
}
fmt.Println("Generating enum model files...")
err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos)
if err != nil {
return err
}
}
fmt.Println("Done")
return nil
}
func generate(schemaInfo postgresmeta.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
err := utils.EnsureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := utils.GenerateTemplate(autoGenWarningTemplate, nil)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := utils.GenerateTemplate(template, metaData)
if err != nil {
return err
}
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,73 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/postgres"
"path"
"strconv"
)
// DBConnection contains postgres connection details
type DBConnection struct {
Host string
Port int
User string
Password string
SslMode string
Params string
DBName string
SchemaName string
}
// Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) error {
db, err := openConnection(dbConn)
defer utils.DBClose(db)
if err != nil {
return err
}
fmt.Println("Retrieving schema information...")
schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &metadata.PostgresQuerySet{})
if err != nil {
return err
}
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect)
if err != nil {
return err
}
return nil
}
func openConnection(dbConn DBConnection) (*sql.DB, error) {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -1,5 +0,0 @@
package jet
type groupByClause interface {
serializeForGroupBy(statement statementType, out *sqlBuilder) error
}

View file

@ -1,170 +0,0 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/utils"
)
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
// Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...projection) InsertStatement
}
func newInsertStatement(t WritableTable, columns []column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
}
}
type insertStatementImpl struct {
table WritableTable
columns []column
rows [][]clause
query SelectStatement
returning []projection
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromValues(value, values))
return i
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromModel(i.getColumns(), data))
return i
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowsFromModels(i.getColumns(), data)...)
return i
}
func (i *insertStatementImpl) RETURNING(projections ...projection) InsertStatement {
i.returning = projections
return i
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement
return i
}
func (i *insertStatementImpl) getColumns() []column {
if len(i.columns) > 0 {
return i.columns
}
return i.table.columns()
}
func (i *insertStatementImpl) DebugSql() (query string, err error) {
return debugSql(i)
}
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
queryData := &sqlBuilder{}
queryData.newLine()
queryData.writeString("INSERT INTO")
if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil")
}
err = i.table.serialize(insertStatement, queryData)
if err != nil {
return
}
if len(i.columns) > 0 {
queryData.writeString("(")
err = serializeColumnNames(i.columns, queryData)
if err != nil {
return
}
queryData.writeString(")")
}
if len(i.rows) == 0 && i.query == nil {
return "", nil, errors.New("jet: no row values or query specified")
}
if len(i.rows) > 0 && i.query != nil {
return "", nil, errors.New("jet: only row values or query has to be specified")
}
if len(i.rows) > 0 {
queryData.writeString("VALUES")
for rowIndex, row := range i.rows {
if rowIndex > 0 {
queryData.writeString(",")
}
queryData.increaseIdent()
queryData.newLine()
queryData.writeString("(")
err = serializeClauseList(insertStatement, row, queryData)
if err != nil {
return "", nil, err
}
queryData.writeByte(')')
queryData.decreaseIdent()
}
}
if i.query != nil {
err = i.query.serialize(insertStatement, queryData)
if err != nil {
return
}
}
if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
return
}
sql, args = queryData.finalize()
return
}
func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(i, db, destination)
}
func (i *insertStatementImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(i, db)
}
func (i *insertStatementImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, i, db)
}

View file

@ -1,13 +0,0 @@
package snaker
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestDb(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Snaker Suite")
}

View file

@ -1,40 +1,16 @@
package snaker
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"gotest.tools/assert"
"testing"
)
var _ = Describe("Snaker", func() {
Describe("SnakeToCamel test", func() {
It("should return an empty string on an empty input", func() {
Expect(SnakeToCamel("")).To(Equal(""))
})
It("should not blow up on trailing _", func() {
Expect(SnakeToCamel("potato_")).To(Equal("Potato"))
})
It("should return a snaked text as camel case", func() {
Expect(SnakeToCamel("this_has_to_be_uppercased")).To(
Equal("ThisHasToBeUppercased"))
})
It("should return a snaked text as camel case, except the word ID", func() {
Expect(SnakeToCamel("this_is_an_id")).To(Equal("ThisIsAnID"))
})
It("should return 'id' not as uppercase", func() {
Expect(SnakeToCamel("this_is_an_identifier")).To(Equal("ThisIsAnIdentifier"))
})
It("should simply work with id", func() {
Expect(SnakeToCamel("id")).To(Equal("ID"))
})
It("should work with initialism where only certain characters are uppercase", func() {
Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient"))
})
})
})
func TestSnakeToCamel(t *testing.T) {
assert.Equal(t, SnakeToCamel(""), "")
assert.Equal(t, SnakeToCamel("potato_"), "Potato")
assert.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
assert.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
assert.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")
assert.Equal(t, SnakeToCamel("id"), "ID")
assert.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient")
}

28
internal/jet/alias.go Normal file
View file

@ -0,0 +1,28 @@
package jet
type alias struct {
expression Expression
alias string
}
func newAlias(expression Expression, aliasName string) Projection {
return &alias{
expression: expression,
alias: aliasName,
}
}
func (a *alias) fromImpl(subQuery SelectTable) Projection {
column := newColumn(a.alias, "", nil)
column.Parent = &column
column.subQuery = subQuery
return &column
}
func (a *alias) serializeForProjection(statement StatementType, out *SqlBuilder) {
a.expression.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(a.alias)
}

View file

@ -86,25 +86,25 @@ func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression {
//---------------------------------------------------//
type binaryBoolExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
boolInterfaceImpl
binaryOpExpression
}
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression {
boolExpression := binaryBoolExpression{}
func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression {
binaryBoolExpression := binaryBoolExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression
binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator, additionalParams...)
binaryBoolExpression.ExpressionInterfaceImpl.Parent = &binaryBoolExpression
binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
return &boolExpression
return &binaryBoolExpression
}
//---------------------------------------------------//
type prefixBoolExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
boolInterfaceImpl
prefixOpExpression
@ -114,7 +114,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio
exp := prefixBoolExpression{}
exp.prefixOpExpression = newPrefixExpression(expression, operator)
exp.expressionInterfaceImpl.parent = &exp
exp.ExpressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp
return &exp
@ -122,7 +122,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio
//---------------------------------------------------//
type postfixBoolOpExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
boolInterfaceImpl
postfixOpExpression
@ -132,7 +132,7 @@ func newPostifxBoolExpression(expression Expression, operator string) BoolExpres
exp := postfixBoolOpExpression{}
exp.postfixOpExpression = newPostfixOpExpression(expression, operator)
exp.expressionInterfaceImpl.parent = &exp
exp.ExpressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp
return &exp

View file

@ -5,9 +5,8 @@ import (
)
func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: nil rhs")
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.col_bool = $1)", true)
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
}
func TestBoolExpressionNOT_EQ(t *testing.T) {
@ -57,6 +56,7 @@ func TestBinaryBoolExpression(t *testing.T) {
boolExpression := Int(2).EQ(Int(3))
assertClauseSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3))
assertProjectionSerialize(t, boolExpression, "$1 = $2", int64(2), int64(3))
assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"),
`($1 = $2) AS "alias_eq_expression"`, int64(2), int64(3))
@ -71,20 +71,6 @@ func TestBoolLiteral(t *testing.T) {
assertClauseSerialize(t, Bool(false), "$1", false)
}
func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`EXISTS (
SELECT $1
FROM db.table2
WHERE table1.col1 = table2.col3
)`, int64(1))
}
func TestBoolExp(t *testing.T) {
assertClauseSerialize(t, BoolExp(String("true")), "$1", "true")
assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true")

51
internal/jet/cast.go Normal file
View file

@ -0,0 +1,51 @@
package jet
type Cast interface {
AS(castType string) Expression
}
type CastImpl struct {
expression Expression
}
func NewCastImpl(expression Expression) Cast {
castImpl := CastImpl{
expression: expression,
}
return &castImpl
}
func (b *CastImpl) AS(castType string) Expression {
castExp := &castExpression{
expression: b.expression,
cast: string(castType),
}
castExp.ExpressionInterfaceImpl.Parent = castExp
return castExp
}
type castExpression struct {
ExpressionInterfaceImpl
expression Expression
cast string
}
func (b *castExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
expression := b.expression
castType := b.cast
if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil {
castOverride(expression, String(castType))(statement, out, options...)
return
}
out.WriteString("CAST(")
expression.serialize(statement, out, options...)
out.WriteString("AS")
out.WriteString(castType + ")")
}

11
internal/jet/cast_test.go Normal file
View file

@ -0,0 +1,11 @@
package jet
import (
"testing"
)
func TestCastAS(t *testing.T) {
assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1))
assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
}

408
internal/jet/clause.go Normal file
View file

@ -0,0 +1,408 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
)
type Clause interface {
Serialize(statementType StatementType, out *SqlBuilder)
}
type ClauseWithProjections interface {
Clause
projections() ProjectionList
}
type ClauseSelect struct {
Distinct bool
Projections []Projection
}
func (s *ClauseSelect) projections() ProjectionList {
return s.Projections
}
func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("SELECT")
if s.Distinct {
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.WriteProjections(statementType, s.Projections)
}
type ClauseFrom struct {
Table Serializer
}
func (f *ClauseFrom) Serialize(statementType StatementType, out *SqlBuilder) {
if f.Table == nil {
return
}
out.NewLine()
out.WriteString("FROM")
out.IncreaseIdent()
f.Table.serialize(statementType, out)
out.DecreaseIdent()
}
type ClauseWhere struct {
Condition BoolExpression
Mandatory bool
}
func (c *ClauseWhere) Serialize(statementType StatementType, out *SqlBuilder) {
if c.Condition == nil {
if c.Mandatory {
panic("jet: WHERE clause not set")
}
return
}
out.NewLine()
out.WriteString("WHERE")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
out.DecreaseIdent()
}
type ClauseGroupBy struct {
List []GroupByClause
}
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) {
if len(c.List) == 0 {
return
}
out.NewLine()
out.WriteString("GROUP BY")
out.IncreaseIdent()
serializeGroupByClauseList(statementType, c.List, out)
out.DecreaseIdent()
}
type ClauseHaving struct {
Condition BoolExpression
}
func (c *ClauseHaving) Serialize(statementType StatementType, out *SqlBuilder) {
if c.Condition == nil {
return
}
out.NewLine()
out.WriteString("HAVING")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
out.DecreaseIdent()
}
type ClauseOrderBy struct {
List []OrderByClause
}
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SqlBuilder) {
if o.List == nil {
return
}
out.NewLine()
out.WriteString("ORDER BY")
out.IncreaseIdent()
serializeOrderByClauseList(statementType, o.List, out)
out.DecreaseIdent()
}
type ClauseLimit struct {
Count int64
}
func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) {
if l.Count >= 0 {
out.NewLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(l.Count)
}
}
type ClauseOffset struct {
Count int64
}
func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) {
if o.Count >= 0 {
out.NewLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(o.Count)
}
}
type ClauseFor struct {
Lock SelectLock
}
func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) {
if f.Lock == nil {
return
}
out.NewLine()
out.WriteString("FOR")
f.Lock.serialize(statementType, out)
}
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() ProjectionList {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}
return nil
}
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlBuilder) {
if len(s.Selects) < 2 {
panic("jet: UNION Statement must contain at least two SELECT statements")
}
for i, selectStmt := range s.Selects {
out.NewLine()
if i > 0 {
out.WriteString(s.Operator)
if s.All {
out.WriteString("ALL")
}
out.NewLine()
}
if selectStmt == nil {
panic("jet: select statement of '" + s.Operator + "' is nil")
}
selectStmt.serialize(statementType, out)
}
s.OrderBy.Serialize(statementType, out)
s.Limit.Serialize(statementType, out)
s.Offset.Serialize(statementType, out)
}
type ClauseUpdate struct {
Table SerializerTable
}
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("UPDATE")
if utils.IsNil(u.Table) {
panic("jet: table to update is nil")
}
u.Table.serialize(statementType, out)
}
type ClauseSet struct {
Columns []Column
Values []Serializer
}
func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("SET")
if len(s.Columns) != len(s.Values) {
panic("jet: mismatch in numbers of columns and values for SET clause")
}
out.IncreaseIdent(4)
for i, column := range s.Columns {
if i > 0 {
out.WriteString(", ")
out.NewLine()
}
if column == nil {
panic("jet: nil column in columns list for SET clause")
}
out.WriteString(column.Name())
out.WriteString(" = ")
s.Values[i].serialize(UpdateStatementType, out)
}
out.DecreaseIdent(4)
}
type ClauseInsert struct {
Table SerializerTable
Columns []Column
}
func (i *ClauseInsert) GetColumns() []Column {
if len(i.Columns) > 0 {
return i.Columns
}
return i.Table.columns()
}
func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("INSERT INTO")
if utils.IsNil(i.Table) {
panic("jet: table is nil for INSERT clause")
}
i.Table.serialize(statementType, out)
if len(i.Columns) > 0 {
out.WriteString("(")
SerializeColumnNames(i.Columns, out)
out.WriteString(")")
}
}
type ClauseValuesQuery struct {
ClauseValues
ClauseQuery
}
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SqlBuilder) {
if len(v.Rows) == 0 && v.Query == nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
if len(v.Rows) > 0 && v.Query != nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
v.ClauseValues.Serialize(statementType, out)
v.ClauseQuery.Serialize(statementType, out)
}
type ClauseValues struct {
Rows [][]Serializer
}
func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) {
if len(v.Rows) == 0 {
return
}
out.WriteString("VALUES")
for rowIndex, row := range v.Rows {
if rowIndex > 0 {
out.WriteString(",")
}
out.IncreaseIdent()
out.NewLine()
out.WriteString("(")
SerializeClauseList(statementType, row, out)
out.WriteByte(')')
out.DecreaseIdent()
}
}
type ClauseQuery struct {
Query SerializerStatement
}
func (v *ClauseQuery) Serialize(statementType StatementType, out *SqlBuilder) {
if v.Query == nil {
return
}
v.Query.serialize(statementType, out)
}
type ClauseDelete struct {
Table SerializerTable
}
func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString("DELETE FROM")
if d.Table == nil {
panic("jet: nil table in DELETE clause")
}
d.Table.serialize(statementType, out)
}
type ClauseStatementBegin struct {
Name string
Tables []SerializerTable
}
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) {
out.NewLine()
out.WriteString(d.Name)
for i, table := range d.Tables {
if i > 0 {
out.WriteString(", ")
}
table.serialize(statementType, out)
}
}
type ClauseOptional struct {
Name string
Show bool
InNewLine bool
}
func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) {
if !d.Show {
return
}
if d.InNewLine {
out.NewLine()
}
out.WriteString(d.Name)
}
type ClauseIn struct {
LockMode string
}
func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) {
if i.LockMode == "" {
return
}
out.WriteString("IN")
out.WriteString(string(i.LockMode))
out.WriteString("MODE")
}

View file

@ -0,0 +1,16 @@
package jet
import (
"gotest.tools/assert"
"testing"
)
func TestClauseSelect_Serialize(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: SELECT clause has to have at least one projection")
}()
selectClause := &ClauseSelect{}
selectClause.Serialize(SelectStatementType, &SqlBuilder{})
}

144
internal/jet/column.go Normal file
View file

@ -0,0 +1,144 @@
// Modeling of columns
package jet
type Column interface {
Name() string
TableName() string
setTableName(table string)
setSubQuery(subQuery SelectTable)
defaultAlias() string
}
// Column is common column interface for all types of columns.
type ColumnExpression interface {
Column
Expression
}
// The base type for real materialized columns.
type columnImpl struct {
ExpressionInterfaceImpl
name string
tableName string
subQuery SelectTable
}
func newColumn(name string, tableName string, parent ColumnExpression) columnImpl {
bc := columnImpl{
name: name,
tableName: tableName,
}
bc.ExpressionInterfaceImpl.Parent = parent
return bc
}
func (c *columnImpl) Name() string {
return c.name
}
func (c *columnImpl) TableName() string {
return c.tableName
}
func (c *columnImpl) setTableName(table string) {
c.tableName = table
}
func (c *columnImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string {
if c.tableName != "" {
return c.tableName + "." + c.name
}
return c.name
}
func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) {
if statement == SetStatementType {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.WriteAlias(c.defaultAlias()) //always quote
return
}
c.serialize(statement, out)
}
func (c columnImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
c.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(c.defaultAlias())
}
func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true)
} else {
if c.tableName != "" {
out.WriteIdentifier(c.tableName)
out.WriteByte('.')
}
out.WriteIdentifier(c.name)
}
}
//------------------------------------------------------//
type IColumnList interface {
Projection
Column
columns() []ColumnExpression
}
func ColumnList(columns ...ColumnExpression) IColumnList {
return columnListImpl(columns)
}
// ColumnList is redefined type to support list of columns as single Projection
type columnListImpl []ColumnExpression
func (cl columnListImpl) columns() []ColumnExpression {
return cl
}
func (cl columnListImpl) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl columnListImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl columnListImpl) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl columnListImpl) TableName() string { return "" }
func (cl columnListImpl) setTableName(name string) {}
func (cl columnListImpl) setSubQuery(subQuery SelectTable) {}
func (cl columnListImpl) defaultAlias() string { return "" }

View file

@ -4,7 +4,7 @@ import "testing"
func TestColumn(t *testing.T) {
column := newColumn("col", "", nil)
column.expressionInterfaceImpl.parent = &column
column.ExpressionInterfaceImpl.Parent = &column
assertClauseSerialize(t, column, "col")
column.setTableName("table1")

View file

@ -3,7 +3,7 @@ package jet
// ColumnBool is interface for SQL boolean columns.
type ColumnBool interface {
BoolExpression
column
Column
From(subQuery SelectTable) ColumnBool
}
@ -14,7 +14,7 @@ type boolColumnImpl struct {
columnImpl
}
func (i *boolColumnImpl) from(subQuery SelectTable) projection {
func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection {
newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery)
@ -23,7 +23,7 @@ func (i *boolColumnImpl) from(subQuery SelectTable) projection {
}
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := i.from(subQuery).(ColumnBool)
newBoolColumn := i.fromImpl(subQuery).(ColumnBool)
return newBoolColumn
}
@ -42,7 +42,7 @@ func BoolColumn(name string) ColumnBool {
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat interface {
FloatExpression
column
Column
From(subQuery SelectTable) ColumnFloat
}
@ -52,7 +52,7 @@ type floatColumnImpl struct {
columnImpl
}
func (i *floatColumnImpl) from(subQuery SelectTable) projection {
func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection {
newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery)
@ -61,7 +61,7 @@ func (i *floatColumnImpl) from(subQuery SelectTable) projection {
}
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
newFloatColumn := i.from(subQuery).(ColumnFloat)
newFloatColumn := i.fromImpl(subQuery).(ColumnFloat)
return newFloatColumn
}
@ -80,7 +80,7 @@ func FloatColumn(name string) ColumnFloat {
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger interface {
IntegerExpression
column
Column
From(subQuery SelectTable) ColumnInteger
}
@ -91,7 +91,7 @@ type integerColumnImpl struct {
columnImpl
}
func (i *integerColumnImpl) from(subQuery SelectTable) projection {
func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection {
newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery)
@ -100,7 +100,7 @@ func (i *integerColumnImpl) from(subQuery SelectTable) projection {
}
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return i.from(subQuery).(ColumnInteger)
return i.fromImpl(subQuery).(ColumnInteger)
}
// IntegerColumn creates named integer column.
@ -118,7 +118,7 @@ func IntegerColumn(name string) ColumnInteger {
// bytea, uuid columns and enums types.
type ColumnString interface {
StringExpression
column
Column
From(subQuery SelectTable) ColumnString
}
@ -129,7 +129,7 @@ type stringColumnImpl struct {
columnImpl
}
func (i *stringColumnImpl) from(subQuery SelectTable) projection {
func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection {
newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery)
@ -138,7 +138,7 @@ func (i *stringColumnImpl) from(subQuery SelectTable) projection {
}
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return i.from(subQuery).(ColumnString)
return i.fromImpl(subQuery).(ColumnString)
}
// StringColumn creates named string column.
@ -155,7 +155,7 @@ func StringColumn(name string) ColumnString {
// ColumnTime is interface for SQL time column.
type ColumnTime interface {
TimeExpression
column
Column
From(subQuery SelectTable) ColumnTime
}
@ -165,7 +165,7 @@ type timeColumnImpl struct {
columnImpl
}
func (i *timeColumnImpl) from(subQuery SelectTable) projection {
func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery)
@ -174,7 +174,7 @@ func (i *timeColumnImpl) from(subQuery SelectTable) projection {
}
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return i.from(subQuery).(ColumnTime)
return i.fromImpl(subQuery).(ColumnTime)
}
// TimeColumn creates named time column
@ -190,7 +190,7 @@ func TimeColumn(name string) ColumnTime {
// ColumnTimez is interface of SQL time with time zone columns.
type ColumnTimez interface {
TimezExpression
column
Column
From(subQuery SelectTable) ColumnTimez
}
@ -201,7 +201,7 @@ type timezColumnImpl struct {
columnImpl
}
func (i *timezColumnImpl) from(subQuery SelectTable) projection {
func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery)
@ -210,7 +210,7 @@ func (i *timezColumnImpl) from(subQuery SelectTable) projection {
}
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return i.from(subQuery).(ColumnTimez)
return i.fromImpl(subQuery).(ColumnTimez)
}
// TimezColumn creates named time with time zone column.
@ -227,7 +227,7 @@ func TimezColumn(name string) ColumnTimez {
// ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp interface {
TimestampExpression
column
Column
From(subQuery SelectTable) ColumnTimestamp
}
@ -238,7 +238,7 @@ type timestampColumnImpl struct {
columnImpl
}
func (i *timestampColumnImpl) from(subQuery SelectTable) projection {
func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery)
@ -247,7 +247,7 @@ func (i *timestampColumnImpl) from(subQuery SelectTable) projection {
}
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return i.from(subQuery).(ColumnTimestamp)
return i.fromImpl(subQuery).(ColumnTimestamp)
}
// TimestampColumn creates named timestamp column
@ -264,7 +264,7 @@ func TimestampColumn(name string) ColumnTimestamp {
// ColumnTimestampz is interface of SQL timestamp with timezone columns.
type ColumnTimestampz interface {
TimestampzExpression
column
Column
From(subQuery SelectTable) ColumnTimestampz
}
@ -275,7 +275,7 @@ type timestampzColumnImpl struct {
columnImpl
}
func (i *timestampzColumnImpl) from(subQuery SelectTable) projection {
func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection {
newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery)
@ -284,7 +284,7 @@ func (i *timestampzColumnImpl) from(subQuery SelectTable) projection {
}
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return i.from(subQuery).(ColumnTimestampz)
return i.fromImpl(subQuery).(ColumnTimestampz)
}
// TimestampzColumn creates named timestamp with time zone column.
@ -301,7 +301,7 @@ func TimestampzColumn(name string) ColumnTimestampz {
// ColumnDate is interface of SQL date columns.
type ColumnDate interface {
DateExpression
column
Column
From(subQuery SelectTable) ColumnDate
}
@ -312,7 +312,7 @@ type dateColumnImpl struct {
columnImpl
}
func (i *dateColumnImpl) from(subQuery SelectTable) projection {
func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection {
newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery)
@ -321,7 +321,7 @@ func (i *dateColumnImpl) from(subQuery SelectTable) projection {
}
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return i.from(subQuery).(ColumnDate)
return i.fromImpl(subQuery).(ColumnDate)
}
// DateColumn creates named date column.

View file

@ -4,7 +4,9 @@ import (
"testing"
)
var subQuery = table1.SELECT(table1ColFloat, table1ColInt).AsTable("sub_query")
var subQuery = &SelectTableImpl{
alias: "sub_query",
}
func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery)

83
internal/jet/dialect.go Normal file
View file

@ -0,0 +1,83 @@
package jet
type Dialect interface {
Name() string
PackageName() string
OperatorSerializeOverride(operator string) SerializeOverride
FunctionSerializeOverride(function string) SerializeOverride
AliasQuoteChar() byte
IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc
}
type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...SerializeOption)
type SerializeOverride func(expressions ...Expression) SerializeFunc
type QueryPlaceholderFunc func(ord int) string
type DialectParams struct {
Name string
PackageName string
OperatorSerializeOverrides map[string]SerializeOverride
FunctionSerializeOverrides map[string]SerializeOverride
AliasQuoteChar byte
IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc
}
func NewDialect(params DialectParams) Dialect {
return &dialectImpl{
name: params.Name,
packageName: params.PackageName,
operatorSerializeOverrides: params.OperatorSerializeOverrides,
functionSerializeOverrides: params.FunctionSerializeOverrides,
aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder,
}
}
type dialectImpl struct {
name string
packageName string
operatorSerializeOverrides map[string]SerializeOverride
functionSerializeOverrides map[string]SerializeOverride
aliasQuoteChar byte
identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc
supportsReturning bool
}
func (d *dialectImpl) Name() string {
return d.name
}
func (d *dialectImpl) PackageName() string {
return d.packageName
}
func (d *dialectImpl) OperatorSerializeOverride(operator string) SerializeOverride {
if d.operatorSerializeOverrides == nil {
return nil
}
return d.operatorSerializeOverrides[operator]
}
func (d *dialectImpl) FunctionSerializeOverride(function string) SerializeOverride {
if d.functionSerializeOverrides == nil {
return nil
}
return d.functionSerializeOverrides[function]
}
func (d *dialectImpl) AliasQuoteChar() byte {
return d.aliasQuoteChar
}
func (d *dialectImpl) IdentifierQuoteChar() byte {
return d.identifierQuoteChar
}
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder
}

View file

@ -1,8 +1,9 @@
package jet
type enumValue struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
stringInterfaceImpl
name string
}
@ -10,13 +11,12 @@ type enumValue struct {
func NewEnumValue(name string) StringExpression {
enumValue := &enumValue{name: name}
enumValue.expressionInterfaceImpl.parent = enumValue
enumValue.ExpressionInterfaceImpl.Parent = enumValue
enumValue.stringInterfaceImpl.parent = enumValue
return enumValue
}
func (e enumValue) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
func (e enumValue) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.insertConstantArgument(e.name)
return nil
}

178
internal/jet/expression.go Normal file
View file

@ -0,0 +1,178 @@
package jet
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
Serializer
Projection
GroupByClause
OrderByClause
// Test expression whether it is a NULL value.
IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression
AS(alias string) Projection
// Expression will be used to sort query result in ascending order
ASC() OrderByClause
// Expression will be used to sort query result in ascending order
DESC() OrderByClause
}
type ExpressionInterfaceImpl struct {
Parent Expression
}
func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
return e.Parent
}
func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.Parent, "IS NULL")
}
func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostifxBoolExpression(e.Parent, "IS NOT NULL")
}
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN")
}
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN")
}
func (e *ExpressionInterfaceImpl) AS(alias string) Projection {
return newAlias(e.Parent, alias)
}
func (e *ExpressionInterfaceImpl) ASC() OrderByClause {
return newOrderByClause(e.Parent, true)
}
func (e *ExpressionInterfaceImpl) DESC() OrderByClause {
return newOrderByClause(e.Parent, false)
}
func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SqlBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) {
e.Parent.serialize(statement, out, noWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct {
lhs, rhs Expression
additionalParam Expression
operator string
}
func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam ...Expression) binaryOpExpression {
binaryExpression := binaryOpExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
if len(additionalParam) > 0 {
binaryExpression.additionalParam = additionalParam[0]
}
return binaryExpression
}
func (c *binaryOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator")
}
if c.rhs == nil {
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, noWrap)
if wrap {
out.WriteString("(")
}
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, options...)
} else {
c.lhs.serialize(statement, out)
out.WriteString(c.operator)
c.rhs.serialize(statement, out)
}
if wrap {
out.WriteString(")")
}
}
// A prefix operator Expression
type prefixOpExpression struct {
expression Expression
operator string
}
func newPrefixExpression(expression Expression, operator string) prefixOpExpression {
prefixExpression := prefixOpExpression{
expression: expression,
operator: operator,
}
return prefixExpression
}
func (p *prefixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("(")
out.WriteString(p.operator)
if p.expression == nil {
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out)
out.WriteString(")")
}
// A postifx operator Expression
type postfixOpExpression struct {
expression Expression
operator string
}
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression {
postfixOpExpression := postfixOpExpression{
expression: expression,
operator: operator,
}
return postfixOpExpression
}
func (p *postfixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if p.expression == nil {
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out)
out.WriteString(p.operator)
}

View file

@ -4,10 +4,13 @@ import (
"testing"
)
func TestInvalidExpression(t *testing.T) {
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
}
func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")
assertClauseSerializeErr(t, table2Col3.ADD(nil), "jet: nil rhs")
}
func TestExpressionIS_NOT_NULL(t *testing.T) {
@ -26,33 +29,14 @@ func TestExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
}
func TestIN(t *testing.T) {
assertClauseSerialize(t, table2ColInt.IN(Int(1), Int(2), Int(3)),
`(table2.col_int IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
}
func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
`($1 NOT IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)))`, float64(1.11))
assertClauseSerialize(t, table2ColInt.NOT_IN(Int(1), Int(2), Int(3)),
`(table2.col_int NOT IN ($1, $2, $3))`, int64(1), int64(2), int64(3))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
)))`, int64(12))
}

View file

@ -81,12 +81,12 @@ func (n *floatInterfaceImpl) MOD(expression NumericExpression) FloatExpression {
}
func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "^")
return POW(n.parent, expression)
}
//---------------------------------------------------//
type binaryFloatExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
floatInterfaceImpl
binaryOpExpression
@ -97,7 +97,7 @@ func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpress
floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
floatExpression.expressionInterfaceImpl.parent = &floatExpression
floatExpression.ExpressionInterfaceImpl.Parent = &floatExpression
floatExpression.floatInterfaceImpl.parent = &floatExpression
return &floatExpression

View file

@ -70,8 +70,8 @@ func TestFloatExpressionMOD(t *testing.T) {
}
func TestFloatExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "(table1.col_float ^ table2.col_float)")
assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "(table1.col_float ^ $1)", float64(2.11))
assertClauseSerialize(t, table1ColFloat.POW(table2ColFloat), "POW(table1.col_float, table2.col_float)")
assertClauseSerialize(t, table1ColFloat.POW(Float(2.11)), "POW(table1.col_float, $1)", float64(2.11))
}
func TestFloatExp(t *testing.T) {

View file

@ -1,7 +1,5 @@
package jet
import "errors"
// ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil)
@ -11,7 +9,7 @@ func ROW(expressions ...Expression) Expression {
// ABSf calculates absolute value from float expression
func ABSf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("ABS", floatExpression)
return NewFloatFunc("ABS", floatExpression)
}
// ABSi calculates absolute value from int expression
@ -19,62 +17,72 @@ func ABSi(integerExpression IntegerExpression) IntegerExpression {
return newIntegerFunc("ABS", integerExpression)
}
// POW calculates power of base with exponent
func POW(base, exponent NumericExpression) FloatExpression {
return NewFloatFunc("POW", base, exponent)
}
// POWER calculates power of base with exponent
func POWER(base, exponent NumericExpression) FloatExpression {
return NewFloatFunc("POWER", base, exponent)
}
// SQRT calculates square root of numeric expression
func SQRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("SQRT", numericExpression)
return NewFloatFunc("SQRT", numericExpression)
}
// CBRT calculates cube root of numeric expression
func CBRT(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("CBRT", numericExpression)
return NewFloatFunc("CBRT", numericExpression)
}
// CEIL calculates ceil of float expression
func CEIL(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("CEIL", floatExpression)
return NewFloatFunc("CEIL", floatExpression)
}
// FLOOR calculates floor of float expression
func FLOOR(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("FLOOR", floatExpression)
return NewFloatFunc("FLOOR", floatExpression)
}
// ROUND calculates round of a float expressions with optional precision
func ROUND(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 {
return newFloatFunc("ROUND", floatExpression, precision[0])
return NewFloatFunc("ROUND", floatExpression, precision[0])
}
return newFloatFunc("ROUND", floatExpression)
return NewFloatFunc("ROUND", floatExpression)
}
// SIGN returns sign of float expression
func SIGN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SIGN", floatExpression)
return NewFloatFunc("SIGN", floatExpression)
}
// TRUNC calculates trunc of float expression with optional precision
func TRUNC(floatExpression FloatExpression, precision ...IntegerExpression) FloatExpression {
if len(precision) > 0 {
return newFloatFunc("TRUNC", floatExpression, precision[0])
return NewFloatFunc("TRUNC", floatExpression, precision[0])
}
return newFloatFunc("TRUNC", floatExpression)
return NewFloatFunc("TRUNC", floatExpression)
}
// LN calculates natural algorithm of float expression
func LN(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LN", floatExpression)
return NewFloatFunc("LN", floatExpression)
}
// LOG calculates logarithm of float expression
func LOG(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("LOG", floatExpression)
return NewFloatFunc("LOG", floatExpression)
}
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) FloatExpression {
return newFloatFunc("AVG", numericExpression)
return NewFloatFunc("AVG", numericExpression)
}
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
@ -109,7 +117,7 @@ func EVERY(boolExpression BoolExpression) BoolExpression {
// MAXf is aggregate function. Returns maximum value of float expression across all input values
func MAXf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MAX", floatExpression)
return NewFloatFunc("MAX", floatExpression)
}
// MAXi is aggregate function. Returns maximum value of int expression across all input values
@ -119,7 +127,7 @@ func MAXi(integerExpression IntegerExpression) IntegerExpression {
// MINf is aggregate function. Returns minimum value of float expression across all input values
func MINf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("MIN", floatExpression)
return NewFloatFunc("MIN", floatExpression)
}
// MINi is aggregate function. Returns minimum value of int expression across all input values
@ -129,7 +137,7 @@ func MINi(integerExpression IntegerExpression) IntegerExpression {
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) FloatExpression {
return newFloatFunc("SUM", floatExpression)
return NewFloatFunc("SUM", floatExpression)
}
// SUMi is aggregate function. Returns sum of expression across all integer expression.
@ -196,14 +204,15 @@ func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression)
}
//
//func CONCAT(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT", expressions...)
//}
//
//func CONCAT_WS(expressions ...Expression) StringExpression {
// return newStringFunc("CONCAT_WS", expressions...)
//}
// CONCAT adds two or more expressions together
func CONCAT(expressions ...Expression) StringExpression {
return newStringFunc("CONCAT", expressions...)
}
// CONCAT_WS adds two or more expressions together with a separator.
func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression {
return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...)
}
// CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding.
@ -235,11 +244,12 @@ func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format)
}
//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression {
// args := []expressions{formatStr}
// args = append(args, formatArgs...)
// return newStringFunc("FORMAT", args...)
//}
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
args := []Expression{formatStr}
args = append(args, formatArgs...)
return newStringFunc("FORMAT", args...)
}
// INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric
@ -336,6 +346,15 @@ func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number)
}
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression {
if len(matchType) > 0 {
return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0]))
}
return newBoolFunc("REGEXP_LIKE", stringExp, pattern)
}
//----------Data Type Formatting Functions ----------------------//
// TO_CHAR converts expression to string with format
@ -350,7 +369,7 @@ func TO_DATE(dateStr, format StringExpression) DateExpression {
// TO_NUMBER converts string to numeric using format
func TO_NUMBER(floatStr, format StringExpression) FloatExpression {
return newFloatFunc("TO_NUMBER", floatStr, format)
return NewFloatFunc("TO_NUMBER", floatStr, format)
}
// TO_TIMESTAMP converts string to time stamp with time zone using format
@ -372,7 +391,7 @@ func CURRENT_TIME(precision ...int) TimezExpression {
var timezFunc *timezFunc
if len(precision) > 0 {
timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0]))
timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0]))
} else {
timezFunc = newTimezFunc("CURRENT_TIME")
}
@ -387,7 +406,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression {
var timestampzFunc *timestampzFunc
if len(precision) > 0 {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0]))
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0]))
} else {
timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP")
}
@ -402,7 +421,7 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc
if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0]))
timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0]))
} else {
timeFunc = newTimeFunc("LOCALTIME")
}
@ -417,9 +436,9 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression {
var timestampFunc *timestampFunc
if len(precision) > 0 {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0]))
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0]))
} else {
timestampFunc = newTimestampFunc("LOCALTIMESTAMP")
timestampFunc = NewTimestampFunc("LOCALTIMESTAMP")
}
timestampFunc.noBrackets = true
@ -463,7 +482,7 @@ func LEAST(value Expression, values ...Expression) Expression {
//--------------------------------------------------------------------//
type funcExpressionImpl struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
name string
expressions []Expression
@ -477,37 +496,34 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
}
if parent != nil {
funcExp.expressionInterfaceImpl.parent = parent
funcExp.ExpressionInterfaceImpl.Parent = parent
} else {
funcExp.expressionInterfaceImpl.parent = funcExp
funcExp.ExpressionInterfaceImpl.Parent = funcExp
}
return funcExp
}
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if f == nil {
return errors.New("jet: Function expressions is nil. ")
func (f *funcExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(f.expressions...)
serializeOverrideFunc(statement, out, options...)
return
}
addBrackets := !f.noBrackets || len(f.expressions) > 0
if addBrackets {
out.writeString(f.name + "(")
out.WriteString(f.name + "(")
} else {
out.writeString(f.name)
out.WriteString(f.name)
}
err := serializeExpressionList(statement, f.expressions, ", ", out)
if err != nil {
return err
}
serializeExpressionList(statement, f.expressions, ", ", out)
if addBrackets {
out.writeString(")")
out.WriteString(")")
}
return nil
}
type boolFunc struct {
@ -529,7 +545,7 @@ type floatFunc struct {
floatInterfaceImpl
}
func newFloatFunc(name string, expressions ...Expression) FloatExpression {
func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
@ -613,7 +629,7 @@ type timestampFunc struct {
timestampInterfaceImpl
}
func newTimestampFunc(name string, expressions ...Expression) *timestampFunc {
func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc {
timestampFunc := &timestampFunc{}
timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc)

View file

@ -0,0 +1,5 @@
package jet
type GroupByClause interface {
serializeForGroupBy(statement StatementType, out *SqlBuilder)
}

View file

@ -106,7 +106,7 @@ func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpressi
}
func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "^")
return IntExp(POW(i.parent, expression))
}
func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression {
@ -131,7 +131,7 @@ func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression)
//---------------------------------------------------//
type binaryIntegerExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
integerInterfaceImpl
binaryOpExpression
@ -140,7 +140,7 @@ type binaryIntegerExpression struct {
func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression {
integerExpression := binaryIntegerExpression{}
integerExpression.expressionInterfaceImpl.parent = &integerExpression
integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression
integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
@ -150,7 +150,7 @@ func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) Int
//---------------------------------------------------//
type prefixIntegerOpExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
integerInterfaceImpl
prefixOpExpression
@ -160,12 +160,30 @@ func newPrefixIntegerOperator(expression IntegerExpression, operator string) Int
integerExpression := prefixIntegerOpExpression{}
integerExpression.prefixOpExpression = newPrefixExpression(expression, operator)
integerExpression.expressionInterfaceImpl.parent = &integerExpression
integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression
return &integerExpression
}
//---------------------------------------------------//
type prefixFloatOpExpression struct {
ExpressionInterfaceImpl
floatInterfaceImpl
prefixOpExpression
}
func newPrefixFloatOperator(expression FloatExpression, operator string) FloatExpression {
floatOpExpression := prefixFloatOpExpression{}
floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator)
floatOpExpression.ExpressionInterfaceImpl.Parent = &floatOpExpression
floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression
return &floatOpExpression
}
//---------------------------------------------------//
type integerExpressionWrapper struct {
integerInterfaceImpl

View file

@ -60,13 +60,28 @@ func TestIntExpressionMOD(t *testing.T) {
}
func TestIntExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "(table1.col_int ^ table2.col_int)")
assertClauseSerialize(t, table1ColInt.POW(Int(11)), "(table1.col_int ^ $1)", int64(11))
assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)")
assertClauseSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, $1)", int64(11))
}
func TestIntExpressionBIT_NOT(t *testing.T) {
assertClauseSerialize(t, BIT_NOT(table2ColInt), "~ table2.col_int")
assertClauseSerialize(t, BIT_NOT(Int(11)), "~ $1", int64(11))
assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)")
assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)")
}
func TestIntExpressionBIT_AND(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_AND(table2ColInt), "(table1.col_int & table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_AND(Int(11)), "(table1.col_int & $1)", int64(11))
}
func TestIntExpressionBIT_OR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_OR(table2ColInt), "(table1.col_int | table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_OR(Int(11)), "(table1.col_int | $1)", int64(11))
}
func TestIntExpressionBIT_XOR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int # table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int # $1)", int64(11))
}
func TestIntExpressionBIT_SHIFT_LEFT(t *testing.T) {

View file

@ -14,8 +14,6 @@ var (
type keywordClause string
func (k keywordClause) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
out.writeString(string(k))
return nil
func (k keywordClause) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString(string(k))
}

View file

@ -0,0 +1,342 @@
package jet
import (
"fmt"
"time"
)
// LiteralExpression is representation of an escaped literal
type LiteralExpression interface {
Expression
Value() interface{}
SetConstant(constant bool)
}
type literalExpressionImpl struct {
ExpressionInterfaceImpl
value interface{}
constant bool
}
func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl {
exp := literalExpressionImpl{value: value}
if len(optionalConstant) > 0 {
exp.constant = optionalConstant[0]
}
exp.ExpressionInterfaceImpl.Parent = &exp
return &exp
}
func ConstLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value)
exp.constant = true
return exp
}
func (l *literalExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if l.constant {
out.insertConstantArgument(l.value)
} else {
out.insertParametrizedArgument(l.value)
}
}
func (l *literalExpressionImpl) Value() interface{} {
return l.value
}
func (l *literalExpressionImpl) SetConstant(constant bool) {
l.constant = constant
}
type integerLiteralExpression struct {
literalExpressionImpl
integerInterfaceImpl
}
// Int is constructor for integer expressions literals.
func Int(value int64) IntegerExpression {
numLiteral := &integerLiteralExpression{}
numLiteral.literalExpressionImpl = *literal(value)
numLiteral.literalExpressionImpl.Parent = numLiteral
numLiteral.integerInterfaceImpl.parent = numLiteral
return numLiteral
}
//---------------------------------------------------//
type boolLiteralExpression struct {
boolInterfaceImpl
literalExpressionImpl
}
// Bool creates new bool literal expression
func Bool(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{}
boolLiteralExpression.literalExpressionImpl = *literal(value)
boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression
return &boolLiteralExpression
}
//---------------------------------------------------//
type floatLiteral struct {
floatInterfaceImpl
literalExpressionImpl
}
// Float creates new float literal expression
func Float(value float64) FloatExpression {
floatLiteral := floatLiteral{}
floatLiteral.literalExpressionImpl = *literal(value)
floatLiteral.floatInterfaceImpl.parent = &floatLiteral
return &floatLiteral
}
//---------------------------------------------------//
type stringLiteral struct {
stringInterfaceImpl
literalExpressionImpl
}
// String creates new string literal expression
func String(value string) StringExpression {
stringLiteral := stringLiteral{}
stringLiteral.literalExpressionImpl = *literal(value)
stringLiteral.stringInterfaceImpl.parent = &stringLiteral
return &stringLiteral
}
//---------------------------------------------------//
type timeLiteral struct {
timeInterfaceImpl
literalExpressionImpl
}
// Time creates new time literal expression
func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression {
timeLiteral := &timeLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...)
timeLiteral.literalExpressionImpl = *literal(timeStr)
timeLiteral.timeInterfaceImpl.parent = timeLiteral
return timeLiteral
}
func TimeT(t time.Time) TimeExpression {
timeLiteral := &timeLiteral{}
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timeInterfaceImpl.parent = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timezLiteral struct {
timezInterfaceImpl
literalExpressionImpl
}
// Timez creates new time with time zone literal expression
func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression {
timezLiteral := timezLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone
timezLiteral.literalExpressionImpl = *literal(timeStr)
return TimezExp(literal(timeStr))
}
func TimezT(t time.Time) TimezExpression {
timeLiteral := &timezLiteral{}
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timezInterfaceImpl.parent = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timestampLiteral struct {
timestampInterfaceImpl
literalExpressionImpl
}
// Timestamp creates new timestamp literal expression
func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression {
timestamp := &timestampLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...)
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampInterfaceImpl.parent = timestamp
return timestamp
}
func TimestampT(t time.Time) TimestampExpression {
timestamp := &timestampLiteral{}
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampInterfaceImpl.parent = timestamp
return timestamp
}
//---------------------------------------------------//
type timestampzLiteral struct {
timestampzInterfaceImpl
literalExpressionImpl
}
// Timestamp creates new timestamp literal expression
func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression {
timestamp := &timestampzLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampzInterfaceImpl.parent = timestamp
return timestamp
}
func TimestampzT(t time.Time) TimestampzExpression {
timestamp := &timestampzLiteral{}
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampzInterfaceImpl.parent = timestamp
return timestamp
}
//---------------------------------------------------//
type dateLiteral struct {
dateInterfaceImpl
literalExpressionImpl
}
//Date creates new date expression
func Date(year int, month time.Month, day int) DateExpression {
dateLiteral := &dateLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day)
dateLiteral.literalExpressionImpl = *literal(timeStr)
dateLiteral.dateInterfaceImpl.parent = dateLiteral
return dateLiteral
}
func DateT(t time.Time) DateExpression {
dateLiteral := &dateLiteral{}
dateLiteral.literalExpressionImpl = *literal(t)
dateLiteral.dateInterfaceImpl.parent = dateLiteral
return dateLiteral
}
func formatNanoseconds(nanoseconds ...time.Duration) string {
if len(nanoseconds) > 0 && nanoseconds[0] != 0 {
duration := fmt.Sprintf("%09d", nanoseconds[0])
i := len(duration) - 1
for ; i >= 3; i-- {
if duration[i] != '0' {
break
}
}
return "." + duration[0:i+1]
}
return ""
}
//--------------------------------------------------//
type nullLiteral struct {
ExpressionInterfaceImpl
}
func newNullLiteral() Expression {
nullExpression := &nullLiteral{}
nullExpression.ExpressionInterfaceImpl.Parent = nullExpression
return nullExpression
}
func (n *nullLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("NULL")
}
//--------------------------------------------------//
type starLiteral struct {
ExpressionInterfaceImpl
}
func newStarLiteral() Expression {
starExpression := &starLiteral{}
starExpression.ExpressionInterfaceImpl.Parent = starExpression
return starExpression
}
func (n *starLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("*")
}
//---------------------------------------------------//
type wrap struct {
ExpressionInterfaceImpl
expressions []Expression
}
func (n *wrap) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("(")
serializeExpressionList(statement, n.expressions, ", ", out)
out.WriteString(")")
}
// WRAP wraps list of expressions with brackets '(' and ')'
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.ExpressionInterfaceImpl.Parent = wrap
return wrap
}
//---------------------------------------------------//
type rawExpression struct {
ExpressionInterfaceImpl
raw string
}
func (n *rawExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString(n.raw)
}
// Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()")
func Raw(raw string) Expression {
rawExp := &rawExpression{raw: raw}
rawExp.ExpressionInterfaceImpl.Parent = rawExp
return rawExp
}

View file

@ -0,0 +1,60 @@
package jet
import (
"testing"
"time"
)
func TestRawExpression(t *testing.T) {
assertClauseSerialize(t, Raw("current_database()"), "current_database()")
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
assertClauseSerialize(t, DateT(timeT), "$1", timeT)
}
func TestTimeLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Time(11, 5, 30), "'11:05:30'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 0), "'11:05:30'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 3*time.Millisecond), "'11:05:30.003'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 30*time.Millisecond), "'11:05:30.030'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 300*time.Millisecond), "'11:05:30.300'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 300*time.Microsecond), "'11:05:30.0003'")
assertClauseDebugSerialize(t, Time(11, 5, 30, 4*time.Nanosecond), "'11:05:30.000000004'")
}
func TestTimeT(t *testing.T) {
timeT := time.Date(2000, 1, 1, 11, 40, 20, 124, time.UTC)
assertClauseDebugSerialize(t, TimeT(timeT), `'2000-01-01 11:40:20.000000124Z'`)
}
func TestTimezLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timez(11, 5, 30, 10*time.Nanosecond, "UTC"), "'11:05:30.00000001 UTC'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 0, "+1"), "'11:05:30 +1'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 3*time.Microsecond, "-7"), "'11:05:30.000003 -7'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 30*time.Millisecond, "+8:00"), "'11:05:30.030 +8:00'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 300*time.Nanosecond, "America/New_Yor"), "'11:05:30.0000003 America/New_Yor'")
assertClauseDebugSerialize(t, Timez(11, 5, 30, 3000*time.Nanosecond, "zulu"), "'11:05:30.000003 zulu'")
}
func TestTimestampLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timestamp(2011, 1, 8, 11, 5, 30), "'2011-01-08 11:05:30'")
assertClauseDebugSerialize(t, Timestamp(2011, 2, 7, 11, 5, 30, 0), "'2011-02-07 11:05:30'")
assertClauseDebugSerialize(t, Timestamp(2011, 3, 6, 11, 5, 30, 3*time.Millisecond), "'2011-03-06 11:05:30.003'")
assertClauseDebugSerialize(t, Timestamp(2011, 4, 5, 11, 5, 30, 30*time.Millisecond), "'2011-04-05 11:05:30.030'")
assertClauseDebugSerialize(t, Timestamp(2011, 5, 4, 11, 5, 30, 300*time.Millisecond), "'2011-05-04 11:05:30.300'")
assertClauseDebugSerialize(t, Timestamp(2011, 6, 3, 11, 5, 30, 3000*time.Microsecond), "'2011-06-03 11:05:30.003'")
}
func TestTimestampzLiteral(t *testing.T) {
assertClauseDebugSerialize(t, Timestampz(2011, 1, 8, 11, 5, 30, 0, "UTC"), "'2011-01-08 11:05:30 UTC'")
assertClauseDebugSerialize(t, Timestampz(2011, 2, 7, 11, 5, 30, 0, "PST"), "'2011-02-07 11:05:30 PST'")
assertClauseDebugSerialize(t, Timestampz(2011, 3, 6, 11, 5, 30, 3, "+4:00"), "'2011-03-06 11:05:30.000000003 +4:00'")
assertClauseDebugSerialize(t, Timestampz(2011, 4, 5, 11, 5, 30, 30, "-8:00"), "'2011-04-05 11:05:30.00000003 -8:00'")
assertClauseDebugSerialize(t, Timestampz(2011, 5, 4, 11, 5, 30, 300, "400"), "'2011-05-04 11:05:30.0000003 400'")
assertClauseDebugSerialize(t, Timestampz(2011, 6, 3, 11, 5, 30, 3000, "zulu"), "'2011-06-03 11:05:30.000003 zulu'")
}
func TestDate(t *testing.T) {
assertClauseDebugSerialize(t, Date(2019, 8, 8), `'2019-08-08'`)
}

View file

@ -1,6 +1,10 @@
package jet
import "errors"
const (
StringConcatOperator = "||"
StringRegexpLikeOperator = "REGEXP"
StringNotRegexpLikeOperator = "NOT REGEXP"
)
//----------- Logical operators ---------------//
@ -11,13 +15,16 @@ func NOT(exp BoolExpression) BoolExpression {
// BIT_NOT inverts every bit in integer expression result
func BIT_NOT(expr IntegerExpression) IntegerExpression {
if literalExp, ok := expr.(LiteralExpression); ok {
literalExp.SetConstant(true)
}
return newPrefixIntegerOperator(expr, "~")
}
//----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery SelectStatement) BoolExpression {
func EXISTS(subQuery Expression) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS")
}
@ -71,7 +78,7 @@ type CaseOperator interface {
}
type caseOperatorImpl struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
expression Expression
when []Expression
@ -87,7 +94,7 @@ func CASE(expression ...Expression) CaseOperator {
caseExp.expression = expression[0]
}
caseExp.expressionInterfaceImpl.parent = caseExp
caseExp.ExpressionInterfaceImpl.Parent = caseExp
return caseExp
}
@ -108,55 +115,33 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
return c
}
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c == nil {
return errors.New("jet: Case Expression is nil. ")
}
out.writeString("(CASE")
func (c *caseOperatorImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString("(CASE")
if c.expression != nil {
err := c.expression.serialize(statement, out)
if err != nil {
return err
}
c.expression.serialize(statement, out)
}
if len(c.when) == 0 || len(c.then) == 0 {
return errors.New("jet: Invalid case Statement. There should be at least one when/then Expression pair. ")
panic("jet: invalid case Statement. There should be at least one WHEN/THEN pair. ")
}
if len(c.when) != len(c.then) {
return errors.New("jet: When and then Expression count mismatch. ")
panic("jet: WHEN and THEN expression count mismatch. ")
}
for i, when := range c.when {
out.writeString("WHEN")
err := when.serialize(statement, out, noWrap)
out.WriteString("WHEN")
when.serialize(statement, out, noWrap)
if err != nil {
return err
}
out.writeString("THEN")
err = c.then[i].serialize(statement, out, noWrap)
if err != nil {
return err
}
out.WriteString("THEN")
c.then[i].serialize(statement, out, noWrap)
}
if c.els != nil {
out.writeString("ELSE")
err := c.els.serialize(statement, out, noWrap)
if err != nil {
return err
}
out.WriteString("ELSE")
c.els.serialize(statement, out, noWrap)
}
out.writeString("END)")
return nil
out.WriteString("END)")
}

View file

@ -5,10 +5,10 @@ import "testing"
func TestOperatorNOT(t *testing.T) {
notExpression := NOT(Int(2).EQ(Int(1)))
assertClauseSerialize(t, NOT(table1ColBool), "NOT table1.col_bool")
assertClauseSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1))
assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1))
assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5))
assertClauseSerialize(t, NOT(table1ColBool), "(NOT table1.col_bool)")
assertClauseSerialize(t, notExpression, "(NOT ($1 = $2))", int64(2), int64(1))
assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `(NOT ($1 = $2)) AS "alias_not_expression"`, int64(2), int64(1))
assertClauseSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `((NOT ($1 = $2)) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5))
}
func TestCase1(t *testing.T) {

View file

@ -0,0 +1,29 @@
package jet
// OrderByClause
type OrderByClause interface {
serializeForOrderBy(statement StatementType, out *SqlBuilder)
}
type orderByClauseImpl struct {
expression Expression
ascent bool
}
func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) {
if o.expression == nil {
panic("jet: nil expression in ORDER BY clause")
}
o.expression.serializeForOrderBy(statement, out)
if o.ascent {
out.WriteString("ASC")
} else {
out.WriteString("DESC")
}
}
func newOrderByClause(expression Expression, ascent bool) OrderByClause {
return &orderByClauseImpl{expression: expression, ascent: ascent}
}

View file

@ -0,0 +1,27 @@
package jet
type Projection interface {
serializeForProjection(statement StatementType, out *SqlBuilder)
fromImpl(subQuery SelectTable) Projection
}
func SerializeForProjection(projection Projection, statementType StatementType, out *SqlBuilder) {
projection.serializeForProjection(statementType, out)
}
// ProjectionList is a redefined type, so that ProjectionList can be used as a Projection.
type ProjectionList []Projection
func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, projection := range cl {
newProjectionList = append(newProjectionList, projection.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ProjectionList) serializeForProjection(statement StatementType, out *SqlBuilder) {
SerializeProjectionList(statement, cl, out)
}

View file

@ -0,0 +1,46 @@
package jet
// SelectLock is interface for SELECT statement locks
type SelectLock interface {
Serializer
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func NewSelectLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
out.WriteString(s.lockStrength)
if s.noWait {
out.WriteString("NOWAIT")
}
if s.skipLocked {
out.WriteString("SKIP LOCKED")
}
}

View file

@ -0,0 +1,42 @@
package jet
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Alias() string
AllColumns() ProjectionList
}
type SelectTableImpl struct {
selectStmt StatementWithProjections
alias string
projections ProjectionList
}
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl {
selectTable := SelectTableImpl{selectStmt: selectStmt, alias: alias}
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return selectTable
}
func (s *SelectTableImpl) Alias() string {
return s.alias
}
func (s *SelectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *SelectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if s == nil {
panic("jet: expression table is nil. ")
}
s.selectStmt.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)
}

View file

@ -0,0 +1,37 @@
package jet
type SerializeOption int
const (
noWrap SerializeOption = iota
)
type StatementType string
const (
SelectStatementType StatementType = "SELECT"
InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE"
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
)
type Serializer interface {
serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption)
}
func Serialize(exp Serializer, statementType StatementType, out *SqlBuilder, options ...SerializeOption) {
exp.serialize(statementType, out, options...)
}
func contains(options []SerializeOption, option SerializeOption) bool {
for _, opt := range options {
if opt == option {
return true
}
}
return false
}

View file

@ -21,8 +21,10 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(uint32(32)), "32")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString("It's text"), "'It''s text'")
assert.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")

176
internal/jet/sql_builder.go Normal file
View file

@ -0,0 +1,176 @@
package jet
import (
"bytes"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
)
type SqlBuilder struct {
Dialect Dialect
Buff bytes.Buffer
Args []interface{}
lastChar byte
ident int
debug bool
}
const defaultIdent = 5
func (s *SqlBuilder) IncreaseIdent(ident ...int) {
if len(ident) > 0 {
s.ident += ident[0]
} else {
s.ident += defaultIdent
}
}
func (s *SqlBuilder) DecreaseIdent(ident ...int) {
toDecrease := defaultIdent
if len(ident) > 0 {
toDecrease = ident[0]
}
if s.ident < toDecrease {
s.ident = 0
}
s.ident -= toDecrease
}
func (s *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) {
s.IncreaseIdent()
SerializeProjectionList(statement, projections, s)
s.DecreaseIdent()
}
func (s *SqlBuilder) NewLine() {
s.write([]byte{'\n'})
s.write(bytes.Repeat([]byte{' '}, s.ident))
}
func (s *SqlBuilder) write(data []byte) {
if len(data) == 0 {
return
}
if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 {
s.Buff.WriteByte(' ')
}
s.Buff.Write(data)
s.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
func (s *SqlBuilder) WriteAlias(str string) {
aliasQuoteChar := string(s.Dialect.AliasQuoteChar())
s.WriteString(aliasQuoteChar + str + aliasQuoteChar)
}
func (s *SqlBuilder) WriteString(str string) {
s.write([]byte(str))
}
func (s *SqlBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap || len(alwaysQuote) > 0 {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar)
} else {
s.WriteString(name)
}
}
func (s *SqlBuilder) WriteByte(b byte) {
s.write([]byte{b})
}
func (s *SqlBuilder) finalize() (string, []interface{}) {
return s.Buff.String() + ";\n", s.Args
}
func (s *SqlBuilder) insertConstantArgument(arg interface{}) {
s.WriteString(argToString(arg))
}
func (s *SqlBuilder) insertParametrizedArgument(arg interface{}) {
if s.debug {
s.insertConstantArgument(arg)
return
}
s.Args = append(s.Args, arg)
argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args))
s.WriteString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

145
internal/jet/statement.go Normal file
View file

@ -0,0 +1,145 @@
package jet
import (
"context"
"database/sql"
"github.com/go-jet/jet/execution"
)
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement interface {
// Sql returns parametrized sql query with list of arguments.
Sql() (query string, args []interface{})
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes.
DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination.
// Destination can be arbitrary structure
Query(db execution.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be of arbitrary structure
QueryContext(context context.Context, db execution.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows.
Exec(db execution.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows.
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
}
type SerializerStatement interface {
Serializer
Statement
}
type StatementWithProjections interface {
Statement
HasProjections
Serializer
}
type HasProjections interface {
projections() ProjectionList
}
type SerializerStatementInterfaceImpl struct {
dialect Dialect
statementType StatementType
parent SerializerStatement
}
func (s *SerializerStatementInterfaceImpl) Sql() (query string, args []interface{}) {
queryData := &SqlBuilder{Dialect: s.dialect}
s.parent.serialize(s.statementType, queryData, noWrap)
query, args = queryData.finalize()
return
}
func (s *SerializerStatementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SqlBuilder{Dialect: s.dialect, debug: true}
s.parent.serialize(s.statementType, sqlBuilder, noWrap)
query, _ = sqlBuilder.finalize()
return
}
func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error {
query, args := s.Sql()
return execution.Query(context.Background(), db, query, args, destination)
}
func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
query, args := s.Sql()
return execution.Query(context, db, query, args, destination)
}
func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.Exec(query, args...)
}
func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.ExecContext(context, query, args...)
}
type ExpressionStatementImpl struct {
ExpressionInterfaceImpl
StatementImpl
}
func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType, out *SqlBuilder) {
s.serialize(statement, out)
}
func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl {
return StatementImpl{
SerializerStatementInterfaceImpl: SerializerStatementInterfaceImpl{
parent: parent,
dialect: Dialect,
statementType: statementType,
},
Clauses: clauses,
}
}
type StatementImpl struct {
SerializerStatementInterfaceImpl
Clauses []Clause
}
func (s *StatementImpl) projections() ProjectionList {
for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections()
}
}
return nil
}
func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
out.WriteString("(")
out.IncreaseIdent()
}
for _, clause := range s.Clauses {
clause.Serialize(statement, out)
}
if !contains(options, noWrap) {
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")
}
}

View file

@ -18,8 +18,9 @@ type StringExpression interface {
LIKE(pattern StringExpression) BoolExpression
NOT_LIKE(pattern StringExpression) BoolExpression
SIMILAR_TO(pattern StringExpression) BoolExpression
NOT_SIMILAR_TO(pattern StringExpression) BoolExpression
REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression
NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression
}
type stringInterfaceImpl struct {
@ -59,7 +60,7 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression {
}
func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression {
return newBinaryStringExpression(s.parent, rhs, "||")
return newBinaryStringExpression(s.parent, rhs, StringConcatOperator)
}
func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression {
@ -70,17 +71,18 @@ func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression
return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE")
}
func (s *stringInterfaceImpl) SIMILAR_TO(pattern StringExpression) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "SIMILAR TO")
func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
}
func (s *stringInterfaceImpl) NOT_SIMILAR_TO(pattern StringExpression) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "NOT SIMILAR TO")
func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
}
//---------------------------------------------------//
type binaryStringExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
stringInterfaceImpl
binaryOpExpression
@ -90,7 +92,7 @@ func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpre
boolExpression := binaryStringExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression
boolExpression.ExpressionInterfaceImpl.Parent = &boolExpression
boolExpression.stringInterfaceImpl.parent = &boolExpression
return &boolExpression

View file

@ -66,14 +66,14 @@ func TestStringNOT_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_LIKE(String("JOHN")), "(table3.col2 NOT LIKE $1)", "JOHN")
}
func TestStringSIMILAR_TO(t *testing.T) {
assertClauseSerialize(t, table3StrCol.SIMILAR_TO(table2ColStr), "(table3.col2 SIMILAR TO table2.col_str)")
assertClauseSerialize(t, table3StrCol.SIMILAR_TO(String("JOHN")), "(table3.col2 SIMILAR TO $1)", "JOHN")
func TestStringREGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP $1)", "JOHN")
}
func TestStringNOT_SIMILAR_TO(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(table2ColStr), "(table3.col2 NOT SIMILAR TO table2.col_str)")
assertClauseSerialize(t, table3StrCol.NOT_SIMILAR_TO(String("JOHN")), "(table3.col2 NOT SIMILAR TO $1)", "JOHN")
func TestStringNOT_REGEXP_LIKE(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP $1)", "JOHN")
}
func TestStringExp(t *testing.T) {

209
internal/jet/table.go Normal file
View file

@ -0,0 +1,209 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
)
type SerializerTable interface {
Serializer
TableInterface
}
type TableInterface interface {
columns() []Column
SchemaName() string
TableName() string
AS(alias string)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...ColumnExpression) TableImpl {
t := TableImpl{
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.setTableName(name)
}
return t
}
type TableImpl struct {
schemaName string
name string
alias string
columnList []ColumnExpression
}
func (t *TableImpl) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
c.setTableName(alias)
}
}
func (t *TableImpl) SchemaName() string {
return t.schemaName
}
func (t *TableImpl) TableName() string {
return t.name
}
func (t *TableImpl) columns() []Column {
ret := []Column{}
for _, col := range t.columnList {
ret = append(ret, col)
}
return ret
}
func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if t == nil {
panic("jet: tableImpl is nil")
}
out.WriteIdentifier(t.schemaName)
out.WriteString(".")
out.WriteIdentifier(t.name)
if len(t.alias) > 0 {
out.WriteString("AS")
out.WriteIdentifier(t.alias)
}
}
type JoinType int
const (
InnerJoin JoinType = iota
LeftJoin
RightJoin
FullJoin
CrossJoin
)
// Join expressions are pseudo readable tables.
type JoinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
joinTable := JoinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
return joinTable
}
func (t *JoinTableImpl) SchemaName() string {
if table, ok := t.lhs.(TableInterface); ok {
return table.SchemaName()
}
return ""
}
func (t *JoinTableImpl) TableName() string {
return ""
}
func (t *JoinTableImpl) Columns() []Column {
var ret []Column
if lhsTable, ok := t.lhs.(TableInterface); ok {
ret = append(ret, lhsTable.columns()...)
}
if rhsTable, ok := t.rhs.(TableInterface); ok {
ret = append(ret, rhsTable.columns()...)
}
return ret
}
func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) {
if t == nil {
panic("jet: Join table is nil. ")
}
if utils.IsNil(t.lhs) {
panic("jet: left hand side of join operation is nil table")
}
t.lhs.serialize(statement, out)
out.NewLine()
switch t.joinType {
case InnerJoin:
out.WriteString("INNER JOIN")
case LeftJoin:
out.WriteString("LEFT JOIN")
case RightJoin:
out.WriteString("RIGHT JOIN")
case FullJoin:
out.WriteString("FULL JOIN")
case CrossJoin:
out.WriteString("CROSS JOIN")
}
if utils.IsNil(t.rhs) {
panic("jet: right hand side of join operation is nil table")
}
t.rhs.serialize(statement, out)
if t.onCondition == nil && t.joinType != CrossJoin {
panic("jet: join condition is nil")
}
if t.onCondition != nil {
out.WriteString("ON")
t.onCondition.serialize(statement, out)
}
}
func UnwindColumns(column1 Column, columns ...Column) []Column {
columnList := []Column{}
if val, ok := column1.(IColumnList); ok {
for _, col := range val.columns() {
columnList = append(columnList, col)
}
columnList = append(columnList, columns...)
} else {
columnList = append(columnList, column1)
columnList = append(columnList, columns...)
}
return columnList
}
func UnwidColumnList(columns []Column) []Column {
ret := []Column{}
for _, col := range columns {
if columnList, ok := col.(IColumnList); ok {
for _, c := range columnList.columns() {
ret = append(ret, c)
}
} else {
ret = append(ret, col)
}
}
return ret
}

85
internal/jet/testutils.go Normal file
View file

@ -0,0 +1,85 @@
package jet
import (
"gotest.tools/assert"
"strconv"
"testing"
)
var DefaultDialect = NewDialect(DialectParams{ // just for tests
AliasQuoteChar: '"',
IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
})
var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
var table2ColInt = IntegerColumn("col_int")
var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool")
var table2ColTime = TimeColumn("col_time")
var table2ColTimez = TimezColumn("col_timez")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2")
var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol)
func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SqlBuilder{Dialect: DefaultDialect}
clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
out := SqlBuilder{Dialect: DefaultDialect}
clause.serialize(SelectStatementType, &out)
}
func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SqlBuilder{Dialect: DefaultDialect, debug: true}
clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SqlBuilder{Dialect: DefaultDialect}
projection.serializeForProjection(SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}

View file

@ -53,7 +53,7 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression {
//---------------------------------------------------//
type prefixTimeExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
timeInterfaceImpl
prefixOpExpression
@ -63,7 +63,7 @@ type prefixTimeExpression struct {
// timeExpr := prefixTimeExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
//
// timeExpr.expressionInterfaceImpl.parent = &timeExpr
// timeExpr.ExpressionInterfaceImpl.parent = &timeExpr
// timeExpr.timeInterfaceImpl.parent = &timeExpr
//
// return &timeExpr

View file

@ -2,52 +2,53 @@ package jet
import (
"testing"
"time"
)
var timeVar = Time(10, 20, 0, 0)
func TestTimeExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.col_time = table2.col_time)")
assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.EQ(timeVar), "(table1.col_time = $1)", "10:20:00")
}
func TestTimeExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.col_time != table2.col_time)")
assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.NOT_EQ(timeVar), "(table1.col_time != $1)", "10:20:00")
}
func TestTimeExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(table2ColTime), "(table1.col_time IS DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.IS_DISTINCT_FROM(timeVar), "(table1.col_time IS DISTINCT FROM $1)", "10:20:00")
}
func TestTimeExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(table2ColTime), "(table1.col_time IS NOT DISTINCT FROM table2.col_time)")
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.IS_NOT_DISTINCT_FROM(timeVar), "(table1.col_time IS NOT DISTINCT FROM $1)", "10:20:00")
}
func TestTimeExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT(table2ColTime), "(table1.col_time < table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.LT(timeVar), "(table1.col_time < $1)", "10:20:00")
}
func TestTimeExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.col_time <= table2.col_time)")
assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.LT_EQ(timeVar), "(table1.col_time <= $1)", "10:20:00")
}
func TestTimeExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT(table2ColTime), "(table1.col_time > table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.GT(timeVar), "(table1.col_time > $1)", "10:20:00")
}
func TestTimeExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.col_time >= table2.col_time)")
assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1::time without time zone)", "10:20:00.000")
assertClauseSerialize(t, table1ColTime.GT_EQ(timeVar), "(table1.col_time >= $1)", "10:20:00")
}
func TestTimeExp(t *testing.T) {
assertClauseSerialize(t, TimeExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1)),
"(table1.col_float < $1::time without time zone)", string("01:01:01.001"))
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)),
"(table1.col_float < $1)", string("01:01:01.001"))
}

View file

@ -1,52 +1,55 @@
package jet
import "testing"
import (
"testing"
"time"
)
var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 0)
var timestamp = Timestamp(2000, 1, 31, 10, 20, 0, 3*time.Millisecond)
func TestTimestampExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.EQ(table2ColTimestamp), "(table1.col_timestamp = table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.EQ(timestamp),
"(table1.col_timestamp = $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
"(table1.col_timestamp = $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(table2ColTimestamp), "(table1.col_timestamp != table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.NOT_EQ(timestamp), "(table1.col_timestamp != $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.IS_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS DISTINCT FROM $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(table2ColTimestamp), "(table1.col_timestamp IS NOT DISTINCT FROM table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.IS_NOT_DISTINCT_FROM(timestamp), "(table1.col_timestamp IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT(table2ColTimestamp), "(table1.col_timestamp < table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.LT(timestamp), "(table1.col_timestamp < $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(table2ColTimestamp), "(table1.col_timestamp <= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.LT_EQ(timestamp), "(table1.col_timestamp <= $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT(table2ColTimestamp), "(table1.col_timestamp > table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.GT(timestamp), "(table1.col_timestamp > $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(table2ColTimestamp), "(table1.col_timestamp >= table2.col_timestamp)")
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
assertClauseSerialize(t, table1ColTimestamp.GT_EQ(timestamp), "(table1.col_timestamp >= $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1::timestamp without time zone)", "2000-01-31 10:20:00.000")
"(table1.col_float < $1)", "2000-01-31 10:20:00.003")
}

View file

@ -51,6 +51,25 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression
return gtEq(t.parent, rhs)
}
//---------------------------------------------------//
type prefixTimestampzOperator struct {
ExpressionInterfaceImpl
timestampzInterfaceImpl
prefixOpExpression
}
func NewPrefixTimestampOperator(operator string, expression Expression) TimestampzExpression {
timeExpr := prefixTimestampzOperator{}
timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr
timeExpr.timestampzInterfaceImpl.parent = &timeExpr
return &timeExpr
}
//-------------------------------------------------
type timestampzExpressionWrapper struct {

View file

@ -1,52 +1,55 @@
package jet
import "testing"
import (
"testing"
"time"
)
var timestampz = Timestampz(2000, 1, 31, 10, 20, 0, 0, 2)
var timestampz = Timestampz(2000, 1, 31, 10, 20, 5, 23*time.Microsecond, "+200")
func TestTimestampzExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.EQ(table2ColTimestampz), "(table1.col_timestampz = table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.EQ(timestampz),
"(table1.col_timestampz = $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
"(table1.col_timestampz = $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(table2ColTimestampz), "(table1.col_timestampz != table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.NOT_EQ(timestampz), "(table1.col_timestampz != $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.IS_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS DISTINCT FROM $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(table2ColTimestampz), "(table1.col_timestampz IS NOT DISTINCT FROM table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.IS_NOT_DISTINCT_FROM(timestampz), "(table1.col_timestampz IS NOT DISTINCT FROM $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT(table2ColTimestampz), "(table1.col_timestampz < table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.LT(timestampz), "(table1.col_timestampz < $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(table2ColTimestampz), "(table1.col_timestampz <= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.LT_EQ(timestampz), "(table1.col_timestampz <= $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT(table2ColTimestampz), "(table1.col_timestampz > table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.GT(timestampz), "(table1.col_timestampz > $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(table2ColTimestampz), "(table1.col_timestampz >= table2.col_timestampz)")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
assertClauseSerialize(t, table1ColTimestampz.GT_EQ(timestampz), "(table1.col_timestampz >= $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1::timestamp with time zone)", "2000-01-31 10:20:00.000 +002")
"(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200")
}

View file

@ -61,7 +61,7 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression {
//---------------------------------------------------//
type prefixTimezExpression struct {
expressionInterfaceImpl
ExpressionInterfaceImpl
timezInterfaceImpl
prefixOpExpression
@ -71,7 +71,7 @@ type prefixTimezExpression struct {
// timeExpr := prefixTimezExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
//
// timeExpr.expressionInterfaceImpl.parent = &timeExpr
// timeExpr.ExpressionInterfaceImpl.parent = &timeExpr
// timeExpr.timezInterfaceImpl.parent = &timeExpr
//
// return &timeExpr

View file

@ -2,50 +2,50 @@ package jet
import "testing"
var timezVar = Timez(10, 20, 0, 0, 4)
var timezVar = Timez(10, 20, 0, 0, "+4:00")
func TestTimezExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.EQ(table2ColTimez), "(table1.col_timez = table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.EQ(timezVar), "(table1.col_timez = $1)", "10:20:00 +4:00")
}
func TestTimezExpressionNOT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.NOT_EQ(table2ColTimez), "(table1.col_timez != table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.NOT_EQ(timezVar), "(table1.col_timez != $1)", "10:20:00 +4:00")
}
func TestTimezExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.IS_DISTINCT_FROM(timezVar), "(table1.col_timez IS DISTINCT FROM $1)", "10:20:00 +4:00")
}
func TestTimezExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(table2ColTimez), "(table1.col_timez IS NOT DISTINCT FROM table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.IS_NOT_DISTINCT_FROM(timezVar), "(table1.col_timez IS NOT DISTINCT FROM $1)", "10:20:00 +4:00")
}
func TestTimezExpressionLT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT(table2ColTimez), "(table1.col_timez < table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.LT(timezVar), "(table1.col_timez < $1)", "10:20:00 +4:00")
}
func TestTimezExpressionLT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.LT_EQ(table2ColTimez), "(table1.col_timez <= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.LT_EQ(timezVar), "(table1.col_timez <= $1)", "10:20:00 +4:00")
}
func TestTimezExpressionGT(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT(table2ColTimez), "(table1.col_timez > table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.GT(timezVar), "(table1.col_timez > $1)", "10:20:00 +4:00")
}
func TestTimezExpressionGT_EQ(t *testing.T) {
assertClauseSerialize(t, table1ColTimez.GT_EQ(table2ColTimez), "(table1.col_timez >= table2.col_timez)")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1::time with time zone)", "10:20:00.000 +04")
assertClauseSerialize(t, table1ColTimez.GT_EQ(timezVar), "(table1.col_timez >= $1)", "10:20:00 +4:00")
}
func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat), "table1.col_float")
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, 4)),
"(table1.col_float < $1::time with time zone)", string("01:01:01.001 +04"))
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")),
"(table1.col_float < $1)", string("01:01:01.000000001 +4:00"))
}

163
internal/jet/utils.go Normal file
View file

@ -0,0 +1,163 @@
package jet
import (
"github.com/go-jet/jet/internal/utils"
"reflect"
)
func serializeOrderByClauseList(statement StatementType, orderByClauses []OrderByClause, out *SqlBuilder) {
for i, value := range orderByClauses {
if i > 0 {
out.WriteString(", ")
}
value.serializeForOrderBy(statement, out)
}
}
func serializeGroupByClauseList(statement StatementType, clauses []GroupByClause, out *SqlBuilder) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
}
if c == nil {
panic("jet: nil clause")
}
c.serializeForGroupBy(statement, out)
}
}
func SerializeClauseList(statement StatementType, clauses []Serializer, out *SqlBuilder) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
}
if c == nil {
panic("jet: nil clause")
}
c.serialize(statement, out)
}
}
func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SqlBuilder) {
for i, value := range expressions {
if i > 0 {
out.WriteString(separator)
}
value.serialize(statement, out)
}
}
func SerializeProjectionList(statement StatementType, projections []Projection, out *SqlBuilder) {
for i, col := range projections {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
if col == nil {
panic("jet: Projection is nil")
}
col.serializeForProjection(statement, out)
}
}
func SerializeColumnNames(columns []Column, out *SqlBuilder) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
out.WriteString(col.Name())
}
}
func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection
for _, column := range columns {
ret = append(ret, column)
}
return ret
}
func valueToClause(value interface{}) Serializer {
if clause, ok := value.(Serializer); ok {
return clause
}
return literal(value)
}
func UnwindRowFromModel(columns []Column, data interface{}) []Serializer {
structValue := reflect.Indirect(reflect.ValueOf(data))
row := []Serializer{}
utils.ValueMustBe(structValue, reflect.Struct, "jet: data has to be a struct")
for _, column := range columns {
columnName := column.Name()
structFieldName := utils.ToGoIdentifier(columnName)
structField := structValue.FieldByName(structFieldName)
if !structField.IsValid() {
panic("missing struct field for column : " + columnName)
}
var field interface{}
if structField.Kind() == reflect.Ptr && structField.IsNil() {
field = nil
} else {
field = reflect.Indirect(structField).Interface()
}
row = append(row, literal(field))
}
return row
}
func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer {
sliceValue := reflect.Indirect(reflect.ValueOf(data))
utils.ValueMustBe(sliceValue, reflect.Slice, "jet: data has to be a slice.")
rows := [][]Serializer{}
for i := 0; i < sliceValue.Len(); i++ {
structValue := sliceValue.Index(i)
rows = append(rows, UnwindRowFromModel(columns, structValue.Interface()))
}
return rows
}
func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer {
row := []Serializer{}
allValues := append([]interface{}{value}, values...)
for _, val := range allValues {
row = append(row, valueToClause(val))
}
return row
}

View file

@ -0,0 +1,148 @@
package testutils
import (
"bytes"
"encoding/json"
"fmt"
"github.com/go-jet/jet/execution"
"github.com/go-jet/jet/internal/jet"
"gotest.tools/assert"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
)
func AssertExec(t *testing.T, stmt jet.Statement, db execution.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db)
assert.NilError(t, err)
rows, err := res.RowsAffected()
assert.NilError(t, err)
if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0])
}
}
func AssertExecErr(t *testing.T, stmt jet.Statement, db execution.DB, errorStr string) {
_, err := stmt.Exec(db)
assert.Error(t, err, errorStr)
}
func getFullPath(relativePath string) string {
goPath := os.Getenv("GOPATH")
return filepath.Join(goPath, "src/github.com/go-jet/jet/tests", relativePath)
}
func PrintJson(v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
fmt.Println(string(jsonText))
}
func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
}
func SaveJsonFile(v interface{}, testRelativePath string) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644)
if err != nil {
panic(err)
}
}
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err)
if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
}
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
assert.Assert(t, string(fileJSONData) == string(jsonData))
//assert.DeepEqual(t, string(fileJSONData), string(jsonData))
}
func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args := query.Sql()
assert.Equal(t, queryStr, expectedQuery)
if len(expectedArgs) == 0 {
return
}
assert.DeepEqual(t, args, expectedArgs)
}
func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) {
defer func() {
r := recover()
assert.Equal(t, r, errorStr)
}()
stmt.Sql()
}
func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
_, args := query.Sql()
if len(expectedArgs) > 0 {
assert.DeepEqual(t, args, expectedArgs)
}
debuqSql := query.DebugSql()
assert.Equal(t, debuqSql, expectedQuery)
}
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
out := jet.SqlBuilder{Dialect: dialect}
jet.Serialize(clause, jet.SelectStatementType, &out)
//fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query)
if len(args) > 0 {
assert.DeepEqual(t, out.Args, args)
}
}
func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
out := jet.SqlBuilder{Dialect: dialect}
jet.Serialize(clause, jet.SelectStatementType, &out)
}
func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet.Projection, query string, args ...interface{}) {
out := jet.SqlBuilder{Dialect: dialect}
jet.SerializeForProjection(projection, jet.SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args)
}
func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest interface{}, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
}()
stmt.Query(db, dest)
}

View file

@ -0,0 +1,70 @@
package testutils
import (
"strings"
"time"
)
func Date(t string) *time.Time {
newTime, err := time.Parse("2006-01-02", t)
if err != nil {
panic(err)
}
return &newTime
}
func TimestampWithoutTimeZone(t string, precision int) *time.Time {
precisionStr := ""
if precision > 0 {
precisionStr = "." + strings.Repeat("9", precision)
}
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000")
if err != nil {
panic(err)
}
return &newTime
}
func TimeWithoutTimeZone(t string) *time.Time {
newTime, err := time.Parse("15:04:05", t)
if err != nil {
panic(err)
}
return &newTime
}
func TimeWithTimeZone(t string) *time.Time {
newTimez, err := time.Parse("15:04:05 -0700", t)
if err != nil {
panic(err)
}
return &newTimez
}
func TimestampWithTimeZone(t string, precision int) *time.Time {
precisionStr := ""
if precision > 0 {
precisionStr = "." + strings.Repeat("9", precision)
}
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
if err != nil {
panic(err)
}
return &newTime
}

View file

@ -1,7 +1,7 @@
package utils
import (
"bytes"
"database/sql"
"github.com/go-jet/jet/internal/3rdparty/snaker"
"go/format"
"os"
@ -9,7 +9,6 @@ import (
"reflect"
"strconv"
"strings"
"text/template"
"time"
)
@ -62,28 +61,6 @@ func EnsureDirPath(dirPath string) error {
return nil
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": ToGoIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// CleanUpGeneratedFiles deletes everything at folder dir.
func CleanUpGeneratedFiles(dir string) error {
exist, err := DirExists(dir)
@ -103,6 +80,14 @@ func CleanUpGeneratedFiles(dir string) error {
return nil
}
func DBClose(db *sql.DB) {
if db == nil {
return
}
db.Close()
}
// DirExists checks if folder at path exist.
func DirExists(path string) (bool, error) {
_, err := os.Stat(path)
@ -159,3 +144,28 @@ func FormatTimestamp(t time.Time) []byte {
func IsNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
}
func MustBe(v interface{}, kind reflect.Kind, errorStr string) {
if reflect.TypeOf(v).Kind() != kind {
panic(errorStr)
}
}
func ValueMustBe(v reflect.Value, kind reflect.Kind, errorStr string) {
if v.Kind() != kind {
panic(errorStr)
}
}
func TypeMustBe(v reflect.Type, kind reflect.Kind, errorStr string) {
if v.Kind() != kind {
panic(errorStr)
}
}
func MustBeInitializedPtr(val interface{}, errorStr string) {
if IsNil(val) {
panic(errorStr)
}
}

View file

@ -1,7 +1,7 @@
package utils
import (
"github.com/stretchr/testify/assert"
"gotest.tools/assert"
"testing"
)

Some files were not shown because too many files have changed in this diff Show more