package jet import ( "bytes" "database/sql/driver" "fmt" "reflect" "sort" "strconv" "strings" "time" "unicode" "github.com/go-jet/jet/v2/internal/3rdparty/pq" "github.com/go-jet/jet/v2/internal/utils/is" "github.com/google/uuid" ) // SQLBuilder generates output SQL type SQLBuilder struct { Dialect Dialect Buff bytes.Buffer Args []interface{} lastChar byte ident int Debug bool } const tabSize = 4 const defaultIdent = 5 // IncreaseIdent adds ident or defaultIdent number of spaces to each new line func (s *SQLBuilder) IncreaseIdent(ident ...int) { if len(ident) > 0 { s.ident += ident[0] } else { s.ident += defaultIdent } } // DecreaseIdent removes ident or defaultIdent number of spaces for each new line func (s *SQLBuilder) DecreaseIdent(ident ...int) { toDecrease := defaultIdent if len(ident) > 0 { toDecrease = ident[0] } if s.ident < toDecrease { s.ident = 0 } s.ident -= toDecrease } // WriteProjections func func (s *SQLBuilder) WriteProjections(statement StatementType, projections []Projection) { s.IncreaseIdent() SerializeProjectionList(statement, projections, s) s.DecreaseIdent() } // WriteRowToJsonProjections serializes slice of projections intended for row_to_json json aggregation func (s *SQLBuilder) WriteRowToJsonProjections(statement StatementType, projections []Projection) { for i, projection := range projections { if i > 0 { s.WriteString(",") s.NewLine() } projection.serializeForRowToJsonProjection(statement, s) } } // NewLine adds new line to output SQL func (s *SQLBuilder) NewLine() { s.write([]byte{'\n'}) s.write(bytes.Repeat([]byte{' '}, s.ident)) } func (s *SQLBuilder) write(data []byte) { if len(data) == 0 { return } if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 { s.Buff.WriteByte(' ') } s.Buff.Write(data) s.lastChar = data[len(data)-1] } func isPreSeparator(b byte) bool { return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' } func isPostSeparator(b byte) bool { return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' } // WriteAlias is used to add alias to output SQL func (s *SQLBuilder) WriteAlias(str string) { aliasQuoteChar := string(s.Dialect.AliasQuoteChar()) s.WriteString(aliasQuoteChar + str + aliasQuoteChar) } // WriteString writes sting to output SQL func (s *SQLBuilder) WriteString(str string) { s.write([]byte(str)) } // WriteJsonObjKey serializes json object key func (s *SQLBuilder) WriteJsonObjKey(key string) { s.WriteString(fmt.Sprintf(`'%s', `, key)) } // WriteIdentifier adds identifier to output SQL func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { if s.shouldQuote(name, alwaysQuote...) { identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) s.WriteString(identQuoteChar + name + identQuoteChar) } else { s.WriteString(name) } } func (s *SQLBuilder) shouldQuote(name string, alwaysQuote ...bool) bool { return s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 } // WriteByte writes byte to output SQL func (s *SQLBuilder) WriteByte(b byte) { s.write([]byte{b}) } func (s *SQLBuilder) finalize() (string, []interface{}) { return s.Buff.String() + ";\n", s.Args } func (s *SQLBuilder) insertConstantArgument(arg interface{}) { s.WriteString(s.argToString(arg)) } func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { if s.Debug { s.insertConstantArgument(arg) return } s.Args = append(s.Args, arg) argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args)) s.WriteString(argPlaceholder) } func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) { type namedArgumentPosition struct { Name string Value interface{} Position int } var namedArgumentPositions []namedArgumentPosition for namedArg, value := range namedArg { rawCopy := raw rawIndex := 0 exists := false // one named argument can occur multiple times inside raw string for { index := strings.Index(rawCopy, namedArg) if index == -1 { break } exists = true namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{ Name: namedArg, Value: value, Position: rawIndex + index, }) rawCopy = rawCopy[index+len(namedArg):] rawIndex += index + len(namedArg) } if !exists { panic("jet: named argument '" + namedArg + "' does not appear in raw query") } } sort.Slice(namedArgumentPositions, func(i, j int) bool { return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position }) for _, namedArgumentPos := range namedArgumentPositions { // if named argument does not exists in raw string do not add argument to the list of arguments // It can happen if the same argument occurs multiple times in postgres query. if !strings.Contains(raw, namedArgumentPos.Name) { continue } s.Args = append(s.Args, namedArgumentPos.Value) currentArgNum := len(s.Args) placeholder := s.Dialect.ArgumentPlaceholder()(currentArgNum) // if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurrence of the argument toReplace := -1 // all occurrences if placeholder == "?" { toReplace = 1 // just one occurrence } if s.Debug { placeholder = s.argToString(namedArgumentPos.Value) } raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace) } s.WriteString(raw) } func (s *SQLBuilder) argToString(value interface{}) string { if is.Nil(value) { return "NULL" } strVal, ok := s.Dialect.ArgumentToString(value) if ok { return strVal } switch bindVal := value.(type) { case bool: if bindVal { return "TRUE" } return "FALSE" case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return integerTypesToString(bindVal) case float32: return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) case float64: return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) case string: return stringQuote(bindVal) case []byte: return stringQuote(string(bindVal)) case uuid.UUID: return stringQuote(bindVal.String()) case time.Time: return stringQuote(string(pq.FormatTimestamp(bindVal))) default: if strBindValue, ok := bindVal.(fmt.Stringer); ok { return stringQuote(strBindValue.String()) } if valuer, ok := bindVal.(driver.Valuer); ok { val, err := valuer.Value() if err != nil { // If valuer for some reason returns an error, we return error string representation. // This is fine because argToString is called only from DebugSQL, and DebugSQL shouldn't be used in production. return err.Error() } return s.argToString(val) } panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) } } func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { case int: return strconv.FormatInt(int64(bindVal), 10) case uint: return strconv.FormatUint(uint64(bindVal), 10) case int8: return strconv.FormatInt(int64(bindVal), 10) case uint8: return strconv.FormatUint(uint64(bindVal), 10) case int16: return strconv.FormatInt(int64(bindVal), 10) case uint16: return strconv.FormatUint(uint64(bindVal), 10) case int32: return strconv.FormatInt(int64(bindVal), 10) case uint32: return strconv.FormatUint(uint64(bindVal), 10) case int64: return strconv.FormatInt(bindVal, 10) case uint64: return strconv.FormatUint(bindVal, 10) } panic("jet: Unsupported integer type: " + reflect.TypeOf(value).String()) } func shouldQuoteIdentifier(identifier string) bool { if len(identifier) == 0 { return true } _, err := strconv.ParseInt(identifier, 10, 64) if err == nil { // if it is a number we should quote it return true } firstChar := rune(identifier[0]) if unicode.IsNumber(firstChar) { return true } // check if contains non ascii characters for _, c := range identifier { if unicode.IsNumber(c) || c == '_' { continue } if c > unicode.MaxASCII || !unicode.IsLetter(c) || unicode.IsUpper(c) { return true } } return false } func stringQuote(value string) string { return `'` + strings.Replace(value, "'", "''", -1) + `'` }