jet/clause.go

259 lines
5.5 KiB
Go
Raw Normal View History

2019-06-21 13:56:57 +02:00
package jet
2019-03-31 09:17:28 +02:00
import (
"bytes"
2019-07-04 17:54:15 +02:00
"github.com/go-jet/jet/internal/utils"
2019-06-11 12:47:35 +02:00
"github.com/google/uuid"
"strconv"
2019-06-11 12:47:35 +02:00
"strings"
"time"
)
2019-03-31 09:17:28 +02:00
// serializeOption is an optional flag passed to clause.serialize to adjust
// how a clause renders itself.
type serializeOption int

const (
	// noWrap is the only option declared here. Presumably it asks the
	// serializer not to wrap its output (e.g. in parentheses); how it is
	// honored is not visible in this file.
	noWrap serializeOption = iota
)
2019-05-07 19:06:21 +02:00
type clause interface {
2019-07-08 10:48:03 +02:00
serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error
}
// contains reports whether option is present in options.
// A nil or empty slice never contains anything.
func contains(options []serializeOption, option serializeOption) bool {
	for i := range options {
		if options[i] == option {
			return true
		}
	}
	return false
}
2019-07-08 10:48:03 +02:00
type sqlBuilder struct {
dialect Dialect
buff bytes.Buffer
args []interface{}
2019-05-12 18:15:23 +02:00
lastChar byte
ident int
2019-05-03 12:51:57 +02:00
}
// statementType identifies the kind of statement being serialized; each
// value is spelled as the statement's leading SQL keyword.
type statementType string

// Statement kinds handed down through serialization.
const (
	selectStatement statementType = "SELECT"
	insertStatement statementType = "INSERT"
	updateStatement statementType = "UPDATE"
	deleteStatement statementType = "DELETE"
	setStatement    statementType = "SET"
	lockStatement   statementType = "LOCK"
)
2019-05-12 18:15:23 +02:00
// defaultIdent is the number of spaces added per indentation level.
const defaultIdent int = 5
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) increaseIdent() {
2019-05-12 18:15:23 +02:00
q.ident += defaultIdent
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) decreaseIdent() {
2019-05-12 18:15:23 +02:00
if q.ident < defaultIdent {
q.ident = 0
}
q.ident -= defaultIdent
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeProjections(statement statementType, projections []projection) error {
2019-05-12 18:15:23 +02:00
q.increaseIdent()
err := serializeProjectionList(statement, projections, q)
q.decreaseIdent()
return err
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeFrom(statement statementType, table ReadableTable) error {
q.newLine()
2019-05-12 18:15:23 +02:00
q.writeString("FROM")
q.increaseIdent()
err := table.serialize(statement, q)
q.decreaseIdent()
return err
2019-05-03 12:51:57 +02:00
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeWhere(statement statementType, where Expression) error {
q.newLine()
2019-05-12 18:15:23 +02:00
q.writeString("WHERE")
q.increaseIdent()
2019-06-27 19:55:21 +02:00
err := where.serialize(statement, q, noWrap)
2019-05-12 18:15:23 +02:00
q.decreaseIdent()
return err
2019-05-03 12:51:57 +02:00
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.newLine()
2019-05-12 18:15:23 +02:00
q.writeString("GROUP BY")
2019-05-03 12:51:57 +02:00
2019-05-12 18:15:23 +02:00
q.increaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent()
return err
2019-05-03 12:51:57 +02:00
}
2019-07-18 17:43:11 +02:00
func (q *sqlBuilder) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.newLine()
2019-05-12 18:15:23 +02:00
q.writeString("ORDER BY")
q.increaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent()
return err
2019-05-03 12:51:57 +02:00
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeHaving(statement statementType, having Expression) error {
q.newLine()
2019-05-12 18:15:23 +02:00
q.writeString("HAVING")
q.increaseIdent()
2019-06-27 19:55:21 +02:00
err := having.serialize(statement, q, noWrap)
2019-05-12 18:15:23 +02:00
q.decreaseIdent()
return err
}
2019-07-08 10:48:03 +02:00
// writeReturning appends a RETURNING clause listing the given projections.
// It writes nothing and returns nil when the list is empty.
//
// The projection list is indented one level deeper than the RETURNING
// keyword, on top of the level writeProjections adds itself. The builder's
// indentation is restored before returning, so clauses written afterwards are
// not shifted.
func (q *sqlBuilder) writeReturning(statement statementType, returning []projection) error {
	if len(returning) == 0 {
		return nil
	}

	q.newLine()
	q.writeString("RETURNING")

	q.increaseIdent()
	err := q.writeProjections(statement, returning)
	// Balance the increaseIdent above so indentation does not leak past this clause.
	q.decreaseIdent()

	return err
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) newLine() {
2019-05-12 18:15:23 +02:00
q.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident))
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) write(data []byte) {
2019-05-12 18:15:23 +02:00
if len(data) == 0 {
return
}
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.buff.Len() > 0 {
q.buff.WriteByte(' ')
}
2019-05-01 14:42:46 +02:00
q.buff.Write(data)
2019-05-12 18:15:23 +02:00
q.lastChar = data[len(data)-1]
}
// isPreSeparator reports whether b, as the last byte already written, makes
// a space before the next token unnecessary: space, '.', ',', '(', newline
// or ':'.
func isPreSeparator(b byte) bool {
	switch b {
	case ' ', '.', ',', '(', '\n', ':':
		return true
	default:
		return false
	}
}
// isPostSeparator reports whether b, as the first byte of the next token,
// makes a space before it unnecessary: space, '.', ',', ')', newline or ':'.
func isPostSeparator(b byte) bool {
	switch b {
	case ' ', '.', ',', ')', '\n', ':':
		return true
	default:
		return false
	}
}
// writeAlias writes str enclosed on both sides by the dialect's alias quote
// character. Quote characters already inside str are written as-is (not escaped).
func (q *sqlBuilder) writeAlias(str string) {
	quote := string(q.dialect.AliasQuoteChar)
	quoted := quote + str + quote
	q.writeString(quoted)
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeString(str string) {
2019-05-12 18:15:23 +02:00
q.write([]byte(str))
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeIdentifier(name string) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
2019-06-17 12:05:52 +02:00
if quoteWrap {
identQuoteChar := string(q.dialect.IdentifierQuoteChar)
q.writeString(identQuoteChar + name + identQuoteChar)
2019-06-17 12:05:52 +02:00
} else {
q.writeString(name)
}
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) writeByte(b byte) {
2019-05-12 18:15:23 +02:00
q.write([]byte{b})
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) finalize() (string, []interface{}) {
2019-05-12 18:15:23 +02:00
return q.buff.String() + ";\n", q.args
}
2019-07-08 10:48:03 +02:00
func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
2019-07-18 17:43:11 +02:00
q.writeString(argToString(arg))
2019-06-03 14:41:39 +02:00
}
2019-07-18 17:43:11 +02:00
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := q.dialect.ArgumentPlaceholder(len(q.args))
2019-05-12 18:15:23 +02:00
q.writeString(argPlaceholder)
}
2019-07-18 17:43:11 +02:00
func argToString(value interface{}) string {
2019-07-20 17:44:43 +02:00
if utils.IsNil(value) {
2019-06-11 12:47:35 +02:00
return "NULL"
}
switch bindVal := value.(type) {
case bool:
if bindVal {
2019-05-12 18:15:23 +02:00
return "TRUE"
}
2019-07-18 17:43:11 +02:00
return "FALSE"
case int8:
2019-05-12 18:15:23 +02:00
return strconv.FormatInt(int64(bindVal), 10)
case int:
2019-05-12 18:15:23 +02:00
return strconv.FormatInt(int64(bindVal), 10)
case int16:
2019-05-12 18:15:23 +02:00
return strconv.FormatInt(int64(bindVal), 10)
case int32:
2019-05-12 18:15:23 +02:00
return strconv.FormatInt(int64(bindVal), 10)
case int64:
2019-05-12 18:15:23 +02:00
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
2019-05-12 18:15:23 +02:00
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
2019-05-12 18:15:23 +02:00
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
2019-05-12 18:15:23 +02:00
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
2019-05-12 18:15:23 +02:00
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
2019-05-12 18:15:23 +02:00
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
2019-05-12 18:15:23 +02:00
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
2019-05-12 18:15:23 +02:00
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
2019-06-11 12:47:35 +02:00
return stringQuote(bindVal)
case []byte:
2019-06-11 12:47:35 +02:00
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
2019-07-04 17:54:15 +02:00
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
2019-07-20 11:22:01 +02:00
return "[Unsupported type]"
}
2019-03-31 14:07:58 +02:00
}
2019-06-11 12:47:35 +02:00
// stringQuote returns value as a single-quoted SQL string literal. Every
// embedded single quote is doubled ('' is the standard SQL escape).
//
// The input is scanned byte by byte. This is safe for UTF-8 because the
// ASCII quote byte never appears inside a multi-byte sequence, and it
// preserves invalid UTF-8 exactly.
func stringQuote(value string) string {
	var sb strings.Builder
	sb.Grow(len(value) + 2)

	sb.WriteByte('\'')
	for i := 0; i < len(value); i++ {
		c := value[i]
		if c == '\'' {
			sb.WriteByte('\'')
		}
		sb.WriteByte(c)
	}
	sb.WriteByte('\'')

	return sb.String()
}