Generator clean up.

Ensure all sql types can be processed.
This commit is contained in:
zer0sub 2019-05-27 13:11:15 +02:00
parent b3a52ceb31
commit 64ba909381
21 changed files with 495 additions and 208 deletions

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"github.com/serenize/snaker"
"strings"
)
type ColumnInfo struct {
@ -13,26 +14,21 @@ type ColumnInfo struct {
EnumName string
}
func (c ColumnInfo) ToSqlBuilderColumnType() string {
func (c ColumnInfo) SqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "BoolColumn"
case "smallint":
case "smallint", "integer", "bigint":
return "IntegerColumn"
case "integer":
return "IntegerColumn"
case "bigint":
return "IntegerColumn"
case "date", "timestamp without time zone", "timestamp with time zone":
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone":
return "TimeColumn"
case "text", "character", "character varying", "bytea", "uuid":
case "USER-DEFINED", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY":
return "StringColumn"
case "real":
return "NumericColumn"
case "numeric", "double precision":
case "real", "numeric", "double precision":
return "NumericColumn"
default:
fmt.Println("Unknownl type: " + c.DataType + ", using string column instead.")
fmt.Println("Unknown sql type: " + c.DataType + ", using string column instead for sql builder.")
return "StringColumn"
}
}
@ -49,11 +45,12 @@ func (c ColumnInfo) GoBaseType() string {
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone":
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone":
return "time.Time"
case "bytea":
return "[]byte"
case "text", "character", "character varying":
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY":
return "string"
case "real":
return "float32"
@ -61,31 +58,21 @@ func (c ColumnInfo) GoBaseType() string {
return "float64"
case "uuid":
return "uuid.UUID"
case "json", "jsonb":
return "types.JSONText"
default:
fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.")
fmt.Println("Unknown sql type: " + c.DataType + ", " + c.EnumName + ", using string instead for model type.")
return "string"
}
}
func (c ColumnInfo) ToGoType() string {
func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
if c.IsNullable && !strings.HasPrefix(typeStr, "[]") {
return "*" + typeStr
}
return typeStr
}
func (c ColumnInfo) ToGoFieldName() string {
return snaker.SnakeToCamel(c.Name)
}
func (c ColumnInfo) ToGoVarName() string {
return snaker.SnakeToCamel(c.Name) + "Column"
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := `

View file

@ -3,7 +3,6 @@ package postgres_metadata
import (
"database/sql"
"github.com/serenize/snaker"
"strings"
)
type TableInfo struct {
@ -32,8 +31,6 @@ func (t TableInfo) GetImports() []string {
imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
case "types.JSONText":
imports["types.JSONText"] = "github.com/sub0zero/go-sqlbuilder/types"
}
}
@ -46,26 +43,10 @@ func (t TableInfo) GetImports() []string {
return ret
}
func (t TableInfo) ToGoModelStructName() string {
return snaker.SnakeToCamel(t.name)
}
func (t TableInfo) ToGoVarName() string {
return snaker.SnakeToCamel(t.name)
}
func (t TableInfo) ToGoStructName() string {
func (t TableInfo) GoStructName() string {
return snaker.SnakeToCamel(t.name) + "Table"
}
func (t TableInfo) ToGoColumnFieldList(sep string) string {
columnNames := []string{}
for _, columnInfo := range t.Columns {
columnNames = append(columnNames, columnInfo.ToGoVarName())
}
return strings.Join(columnNames, sep)
}
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) {
tableInfo.SchemaName = schemaName

View file

@ -13,47 +13,57 @@ var autoGenWarningTemplate = `
`
var sqlBuilderTableTemplate = `package table
var sqlBuilderTableTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{camelize $c.Name}}Column
{{- end}}
{{- end}}
{{define "nullable" -}}
{{- if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}}
{{- end}}
package table
import (
"github.com/sub0zero/go-sqlbuilder/sqlbuilder"
)
type {{.ToGoStructName}} struct {
type {{.GoStructName}} struct {
sqlbuilder.Table
//Columns
{{- range .Columns}}
{{.ToGoFieldName}} *sqlbuilder.{{.ToSqlBuilderColumnType}}
{{camelize .Name}} *sqlbuilder.{{.SqlBuilderColumnType}}
{{- end}}
AllColumns sqlbuilder.ColumnList
}
var {{.ToGoVarName}} = new{{.ToGoStructName}}()
var {{camelize .Name}} = new{{.GoStructName}}()
func new{{.ToGoStructName}}() *{{.ToGoStructName}} {
func new{{.GoStructName}}() *{{.GoStructName}} {
var (
{{- range .Columns}}
{{.ToGoVarName}} = sqlbuilder.New{{.ToSqlBuilderColumnType}}("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}})
{{camelize .Name}}Column = sqlbuilder.New{{.SqlBuilderColumnType}}("{{.Name}}", {{template "nullable" .}})
{{- end}}
)
return &{{.ToGoStructName}}{
Table: *sqlbuilder.NewTable("{{.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}),
return &{{.GoStructName}}{
Table: *sqlbuilder.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
//Columns
{{- range .Columns}}
{{.ToGoFieldName}}: {{.ToGoVarName}},
{{camelize .Name}}: {{camelize .Name}}Column,
{{- end}}
AllColumns: sqlbuilder.ColumnList{ {{.ToGoColumnFieldList ", "}} },
AllColumns: sqlbuilder.ColumnList{ {{template "column-list" .Columns}} },
}
}
func (a *{{.ToGoStructName}}) AS(alias string) *{{.ToGoStructName}} {
aliasTable := new{{.ToGoStructName}}()
func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
aliasTable := new{{.GoStructName}}()
aliasTable.Table.SetAlias(alias)
@ -73,9 +83,9 @@ import (
{{end}}
type {{.ToGoModelStructName}} struct {
type {{camelize .Name}} struct {
{{- range .Columns}}
{{.ToGoFieldName}} {{.ToGoType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}}
{{camelize .Name}} {{.GoModelType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}}
{{- end}}
}
`

View file

@ -3,7 +3,7 @@ package sqlbuilder
import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
)
type deleteStatement interface {
@ -71,10 +71,10 @@ func (d *deleteStatementImpl) DebugSql() (query string, err error) {
return DebugSql(d)
}
func (d *deleteStatementImpl) Query(db types.Db, destination interface{}) error {
func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(d, db, destination)
}
func (d *deleteStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
func (d *deleteStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(d, db)
}

View file

@ -1,4 +1,4 @@
package types
package execution
import "database/sql"

View file

@ -6,14 +6,13 @@ import (
"errors"
"fmt"
"github.com/serenize/snaker"
"github.com/sub0zero/go-sqlbuilder/types"
"reflect"
"strconv"
"strings"
"time"
)
func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error {
func Query(db Db, query string, args []interface{}, destinationPtr interface{}) error {
if destinationPtr == nil {
return errors.New("Destination is nil. ")
@ -54,7 +53,7 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa
}
}
func queryToSlice(db types.Db, query string, args []interface{}, slicePtr interface{}) error {
func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) error {
if db == nil {
return errors.New("db is nil")
}
@ -527,8 +526,8 @@ func isGoBaseType(objType reflect.Type) bool {
typeStr := objType.String()
switch typeStr {
case "string", "int", "int32", "int16", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8",
"*string", "*int", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
case "string", "int", "int16", "int32", "int64", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8",
"*string", "*int", "*int16", "*int32", "*int64", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
return true
}
@ -544,10 +543,10 @@ func setReflectValue(source, destination reflect.Value) error {
if source.CanAddr() {
sourceElem = source.Addr()
} else {
newDestination := reflect.New(destination.Type().Elem())
newDestination.Elem().Set(source)
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
sourceElem = newDestination
sourceElem = sourceCopy
}
}
} else {
@ -599,6 +598,7 @@ var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullTimeType = reflect.TypeOf(NullTime{})
var nullByteArrayType = reflect.TypeOf(NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
@ -608,7 +608,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
return nullInt32Type
case "INT8":
return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "BYTEA", "UUID", "JSON", "JSONB":
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloatType
@ -616,10 +616,13 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
return nullFloat64Type
case "BOOL":
return nullBoolType
case "DATE", "TIMESTAMP", "TIMESTAMPTZ":
case "BYTEA":
return nullByteArrayType
case "DATE", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
panic("Unknown column database type " + columnType.DatabaseTypeName())
fmt.Println("Unknown column database type " + columnType.DatabaseTypeName() + " using string as default.")
return nullStringType
}
}

View file

@ -5,6 +5,27 @@ import (
"time"
)
// NullByteArray
type NullByteArray struct {
ByteArray []byte
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error {
nb.ByteArray, nb.Valid = value.([]byte)
return nil
}
// Value implements the driver Valuer interface.
func (nb NullByteArray) Value() (driver.Value, error) {
if !nb.Valid {
return nil, nil
}
return nb.ByteArray, nil
}
//NullTime
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL

View file

@ -4,7 +4,7 @@ import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/serenize/snaker"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"reflect"
"strings"
)
@ -39,11 +39,11 @@ type insertStatementImpl struct {
errors []string
}
func (s *insertStatementImpl) Query(db types.Db, destination interface{}) error {
func (s *insertStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
func (u *insertStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(u, db)
}

View file

@ -3,7 +3,7 @@ package sqlbuilder
import (
"database/sql"
"github.com/pkg/errors"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
)
type lockMode string
@ -92,10 +92,10 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
return
}
func (l *lockStatementImpl) Query(db types.Db, destination interface{}) error {
func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(l, db, destination)
}
func (l *lockStatementImpl) Execute(db types.Db) (sql.Result, error) {
func (l *lockStatementImpl) Execute(db execution.Db) (sql.Result, error) {
return Execute(l, db)
}

View file

@ -3,7 +3,7 @@ package sqlbuilder
import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
)
type selectStatement interface {
@ -259,11 +259,11 @@ func (s *selectStatementImpl) FOR_UPDATE() selectStatement {
return s
}
func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error {
func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (s *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
func (s *selectStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(s, db)
}

View file

@ -3,7 +3,7 @@ package sqlbuilder
import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
)
type setStatement interface {
@ -197,10 +197,10 @@ func (s *setStatementImpl) DebugSql() (query string, err error) {
return DebugSql(s)
}
func (s *setStatementImpl) Query(db types.Db, destination interface{}) error {
func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (u *setStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
func (u *setStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(u, db)
}

View file

@ -2,7 +2,7 @@ package sqlbuilder
import (
"database/sql"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"strconv"
"strings"
)
@ -13,8 +13,8 @@ type Statement interface {
DebugSql() (query string, err error)
Query(db types.Db, destination interface{}) error
Execute(db types.Db) (sql.Result, error)
Query(db execution.Db, destination interface{}) error
Execute(db execution.Db) (sql.Result, error)
}
func DebugSql(statement Statement) (string, error) {

View file

@ -3,7 +3,7 @@ package sqlbuilder
import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
)
type updateStatement interface {
@ -135,10 +135,10 @@ func (u *updateStatementImpl) DebugSql() (query string, err error) {
return DebugSql(u)
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
func (u *updateStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(u, db)
}

View file

@ -4,7 +4,6 @@ import (
"database/sql"
"errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
)
func serializeOrderByClauseList(statement statementType, orderByClauses []orderByClause, out *queryData) error {
@ -114,7 +113,7 @@ func serializeColumnList(statement statementType, columns []column, out *queryDa
return nil
}
func Query(statement Statement, db types.Db, destination interface{}) error {
func Query(statement Statement, db execution.Db, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
@ -124,7 +123,7 @@ func Query(statement Statement, db types.Db, destination interface{}) error {
return execution.Query(db, query, args, destination)
}
func Execute(statement Statement, db types.Db) (res sql.Result, err error) {
func Execute(statement Statement, db execution.Db) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {

138
tests/data/test_sample.sql Normal file
View file

@ -0,0 +1,138 @@
DROP TABLE IF EXISTS test_sample.all_types;
CREATE TABLE test_sample.all_types
(
-- numeric
smallint_ptr smallint,
smallint smallint NOT NULL,
integer_ptr integer,
integer integer NOT NULL,
bigint_ptr bigint,
bigint bigint NOT NULL,
decimal_ptr decimal(10, 2),
decimal decimal(10, 2) NOT NULL,
numeric_ptr numeric(20, 3),
numeric numeric(20,3) NOT NULL,
real_ptr real,
real real NOT NULL,
double_precision_ptr double precision,
double_precision double precision NOT NULL,
smallserial smallserial NOT NULL,
serial serial NOT NULL,
bigserial bigserial NOT NULL,
--monetary
-- money_ptr money,
-- money money NOT NULL,
character_varying_ptr character varying(100),
character_varying character varying(200) NOT NULL,
character_ptr character(80),
character character(80) NOT NULL,
text_ptr text,
text text NOT NULL,
--binary
bytea_ptr bytea,
bytea bytea NOT NULL,
--datetime
timestampz_ptr timestamp with time zone,
timestampz timestamp with time zone NOT NULL,
timestamp_ptr timestamp without time zone,
timestamp timestamp without time zone NOT NULL,
date_ptr date,
date date NOT NULL,
timez_ptr time with time zone,
timez time with time zone NOT NULL,
time_ptr time without time zone,
time time without time zone NOT NULL,
interval_ptr interval,
interval interval NOT NULL,
--boolean
boolean_ptr boolean,
boolean boolean NOT NULL,
--geometry
point_ptr point,
--bitstrings
bit_ptr bit(3),
bit bit(3) NOT NULL,
bit_varying_ptr bit varying(20),
bit_varying bit varying(40) NOT NULL,
--textsearch
tsvector_ptr tsvector,
tsvector tsvector NOT NULL,
--uuid
uuid_ptr uuid,
uuid uuid NOT NULL,
--xml
xml_ptr xml,
xml xml NOT NULL,
--json
json_ptr json,
json json NOT NULL,
jsonb_ptr jsonb,
jsonb jsonb NOT NULL,
--array
integer_array_ptr integer[],
integer_array integer[] NOT NULL,
text_array_ptr text[],
text_array text[] NOT NULL,
jsonb_array jsonb[] NOT NULL,
text_multi_dim_array_ptr text[][],
text_multi_dim_array text[][] NOT NULL
);
INSERT INTO test_sample.all_types(
smallint_ptr, "smallint", integer_ptr, "integer", bigint_ptr, "bigint", decimal_ptr, "decimal", numeric_ptr, "numeric", real_ptr, "real", double_precision_ptr, double_precision, smallserial, serial, bigserial,
-- money_ptr, money,
character_varying_ptr, character_varying, character_ptr, "character", text_ptr, text,
bytea_ptr, bytea,
timestampz_ptr, timestampz, timestamp_ptr, "timestamp", date_ptr, date, timez_ptr, timez, time_ptr, "time", interval_ptr, "interval",
boolean_ptr, "boolean",
point_ptr,
bit_ptr, "bit", bit_varying_ptr, bit_varying,
tsvector_ptr, tsvector,
uuid_ptr, uuid,
xml_ptr, xml,
json_ptr, json, jsonb_ptr, jsonb,
integer_array_ptr, integer_array, text_array_ptr, text_array, jsonb_array, text_multi_dim_array_ptr, text_multi_dim_array)
VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 11111111.22, 11111111.22, DEFAULT, DEFAULT, DEFAULT,
-- 100000, 100000,
'ABBA', 'ABBA', 'JOHN', 'JOHN', 'Some text', 'Some text',
'bytea', 'bytea',
'January 8 04:05:06 1999 PST', 'January 8 04:05:06 1999 PST', '1999-01-08 04:05:06', '1999-01-08 04:05:06', '1999-01-08', '1999-01-08', '04:05:06 -8:00', '04:05:06 -8:00', '04:05:06', '04:05:06', '3 4:05:06', '3 4:05:06',
TRUE, FALSE,
'(2,3)',
B'101', B'101', B'101111', B'101111',
to_tsvector('supernovae'), to_tsvector('supernovae'),
'A0EEBC99-9C0B-4EF8-BB6D-6BB9BD380A11', 'A0EEBC99-9C0B-4EF8-BB6D-6BB9BD380A11',
'<Sub>abc</Sub>', '<Sub>abc</Sub>',
'{"a": 1, "b": 3}', '{"a": 1, "b": 3}', '{"a": 1, "b": 3}', '{"a": 1, "b": 3}',
'{1, 2, 3}', '{1, 2, 3}', '{"breakfast", "consulting"}', '{"breakfast", "consulting"}', ARRAY['{"a": 1, "b": 2}'::jsonb, '{"a":3, "b": 4}'::jsonb], '{{"meeting", "lunch"}, {"training", "presentation"}}', '{{"meeting", "lunch"}, {"training", "presentation"}}')
,
(NULL, 1, NULL, 300, NULL, 5000, NULL, 11.44, NULL, 55.77, NULL, 99.1, NULL, 11111111.22, DEFAULT, DEFAULT, DEFAULT,
-- NULL, 100000,
NULL, 'ABBA', NULL, 'JOHN', NULL, 'Some text',
NULL, 'bytea',
NULL, 'January 8 04:05:06 1999 PST', NULL, '1999-01-08 04:05:06', NULL, '1999-01-08', NULL, '04:05:06 -8:00', NULL, '04:05:06', NULL, '3 4:05:06',
NULL, FALSE,
NULL,
NULL, B'101', NULL, B'101111',
NULL, to_tsvector('supernovae'),
NULL, 'A0EEBC99-9C0B-4EF8-BB6D-6BB9BD380A11',
NULL, '<Sub>abc</Sub>',
NULL, '{"a": 1, "b": 3}', NULL, '{"a": 1, "b": 3}',
NULL, '{1, 2, 3}', NULL, '{"breakfast", "consulting"}', ARRAY['{"a": 1, "b": 2}'::jsonb, '{"a":3, "b": 4}'::jsonb], NULL, '{{"meeting", "lunch"}, {"training", "presentation"}}')
;

View file

@ -136,5 +136,5 @@ func TestGenerateModel(t *testing.T) {
staff := model.Staff{}
assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "[]uint8")
}

View file

@ -185,8 +185,6 @@ func TestScanToStruct(t *testing.T) {
err := query.Query(db, &dest)
assert.Error(t, err, `Scan: unable to scan type int32 into UUID, at struct field: InventoryID uuid.UUID of type tests.Inventory. `)
fmt.Println(err)
})
}
@ -681,7 +679,7 @@ var address256 = model.Address{
CityID: 312,
PostalCode: stringPtr("3433"),
Phone: "246810237916",
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:30", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 09:45:30", 0),
}
var addres517 = model.Address{
@ -692,7 +690,7 @@ var addres517 = model.Address{
CityID: 312,
PostalCode: stringPtr("35653"),
Phone: "879347453467",
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:30", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 09:45:30", 0),
}
var customer256 = model.Customer{
@ -703,8 +701,8 @@ var customer256 = model.Customer{
Email: stringPtr("mattie.hoffman@sakilacustomer.org"),
AddressID: 256,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 0),
CreateDate: *timestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timestampWithoutTimeZone("2013-05-26 14:49:45.738", 0),
Active: int32Ptr(1),
}
@ -716,36 +714,36 @@ var customer512 = model.Customer{
Email: stringPtr("cecil.vines@sakilacustomer.org"),
AddressID: 517,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 0),
CreateDate: *timestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timestampWithoutTimeZone("2013-05-26 14:49:45.738", 0),
Active: int32Ptr(1),
}
var countryUk = model.Country{
CountryID: 102,
Country: "United Kingdom",
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:44:00", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 09:44:00", 0),
}
var cityLondon = model.City{
CityID: 312,
City: "London",
CountryID: 102,
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:25", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 09:45:25", 0),
}
var inventory1 = model.Inventory{
InventoryID: 1,
FilmID: 1,
StoreID: 1,
LastUpdate: *timeWithoutTimeZone("2006-02-15 10:09:17", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 10:09:17", 0),
}
var inventory2 = model.Inventory{
InventoryID: 2,
FilmID: 1,
StoreID: 1,
LastUpdate: *timeWithoutTimeZone("2006-02-15 10:09:17", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 10:09:17", 0),
}
var film1 = model.Film{
@ -759,7 +757,7 @@ var film1 = model.Film{
Length: int16Ptr(86),
ReplacementCost: 20.99,
Rating: &pgRating,
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
LastUpdate: *timestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: stringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"),
Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17",
}
@ -775,7 +773,7 @@ var film2 = model.Film{
Length: int16Ptr(48),
ReplacementCost: 12.99,
Rating: &gRating,
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
LastUpdate: *timestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: stringPtr(`{Trailers,"Deleted Scenes"}`),
Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`,
}
@ -784,7 +782,7 @@ var store1 = model.Store{
StoreID: 1,
ManagerStaffID: 1,
AddressID: 1,
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:57:12", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 09:57:12", 0),
}
var pgRating = model.MpaaRating_PG
@ -793,5 +791,5 @@ var gRating = model.MpaaRating_G
var language1 = model.Language{
LanguageID: 1,
Name: "English ",
LastUpdate: *timeWithoutTimeZone("2006-02-15 10:02:19", 0),
LastUpdate: *timestampWithoutTimeZone("2006-02-15 10:02:19", 0),
}

View file

@ -2,6 +2,7 @@ package tests
import (
"fmt"
"github.com/davecgh/go-spew/spew"
. "github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
@ -36,7 +37,7 @@ WHERE actor.actor_id = 1;
ActorID: 1,
FirstName: "Penelope",
LastName: "Guiness",
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:47:57.62", 2),
LastUpdate: *timestampWithoutTimeZone("2013-05-26 14:47:57.62", 2),
}
assert.DeepEqual(t, actor, expectedActor)
@ -834,7 +835,7 @@ ORDER BY film.film_id ASC;
ReplacementCost: 12.99,
Rating: &gRating,
RentalDuration: 3,
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
LastUpdate: *timestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"),
Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14",
})
@ -936,14 +937,24 @@ ORDER BY customer_payment_sum.amount_sum ASC;
AddressID: 323,
Email: stringPtr("brian.wyman@sakilacustomer.org"),
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),
CreateDate: *timestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: int32Ptr(1),
})
assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93)
}
func TestSelectStaff(t *testing.T) {
staffs := []model.Staff{}
err := Staff.SELECT(Staff.AllColumns).Query(db, &staffs)
assert.NilError(t, err)
spew.Dump(staffs)
}
func TestSelectTimeColumns(t *testing.T) {
expectedSql := `
@ -979,7 +990,7 @@ ORDER BY payment.payment_date ASC;
StaffID: 2,
RentalID: 1158,
Amount: 2.99,
PaymentDate: *timeWithoutTimeZone("2007-02-14 21:21:59.996577", 6),
PaymentDate: *timestampWithoutTimeZone("2007-02-14 21:21:59.996577", 6),
})
}

View file

@ -1,6 +1,7 @@
package tests
import (
"github.com/google/uuid"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
"gotest.tools/assert"
@ -20,6 +21,10 @@ func assertQuery(t *testing.T, query sqlbuilder.Statement, expectedQuery string,
assert.Equal(t, debuqSql, expectedQuery)
}
func boolPtr(b bool) *bool {
return &b
}
func int16Ptr(i int16) *int16 {
return &i
}
@ -28,11 +33,48 @@ func int32Ptr(i int32) *int32 {
return &i
}
func int64Ptr(i int64) *int64 {
return &i
}
func stringPtr(s string) *string {
return &s
}
func timeWithoutTimeZone(t string, precision int) *time.Time {
func float32Ptr(f float32) *float32 {
return &f
}
func float64Ptr(f float64) *float64 {
return &f
}
func uuidPtr(u string) *uuid.UUID {
uuid := uuid.MustParse(u)
return &uuid
}
func timeWithoutTimeZone(t string) *time.Time {
time, err := time.Parse("15:04:05", t)
if err != nil {
panic(err)
}
return &time
}
func timeWithTimeZone(t string) *time.Time {
time, err := time.Parse("15:04:05 -0700", t)
if err != nil {
panic(err)
}
return &time
}
func timestampWithoutTimeZone(t string, precision int) *time.Time {
precisionStr := ""
@ -49,6 +91,27 @@ func timeWithoutTimeZone(t string, precision int) *time.Time {
return &time
}
func timestampWithTimeZone(t string, precision int) *time.Time {
precisionStr := ""
if precision > 0 {
precisionStr = "." + strings.Repeat("9", precision)
}
time, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
if err != nil {
panic(err)
}
return &time
}
func M3(a, b, c interface{}) []interface{} {
return []interface{}{a, b, c}
}
var customer0 = model.Customer{
CustomerID: 1,
StoreID: 1,
@ -57,8 +120,8 @@ var customer0 = model.Customer{
Email: stringPtr("mary.smith@sakilacustomer.org"),
AddressID: 5,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),
CreateDate: *timestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: int32Ptr(1),
}
@ -70,8 +133,8 @@ var customer1 = model.Customer{
Email: stringPtr("patricia.johnson@sakilacustomer.org"),
AddressID: 6,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),
CreateDate: *timestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: int32Ptr(1),
}
@ -83,7 +146,7 @@ var lastCustomer = model.Customer{
Email: stringPtr("austin.cintron@sakilacustomer.org"),
AddressID: 605,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),
CreateDate: *timestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: int32Ptr(1),
}

157
tests/types_test.go Normal file
View file

@ -0,0 +1,157 @@
package tests
import (
"fmt"
"github.com/google/uuid"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert"
"testing"
)
func TestAllTypesSelect(t *testing.T) {
dest := []model.AllTypes{}
err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest)
fmt.Println(err)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0], dest0)
assert.DeepEqual(t, dest[1], dest1)
}
var dest0 = model.AllTypes{
SmallintPtr: int16Ptr(1),
Smallint: 1,
IntegerPtr: int32Ptr(300),
Integer: 300,
BigintPtr: int64Ptr(50000),
Bigint: 5000,
DecimalPtr: float64Ptr(11.44),
Decimal: 11.44,
NumericPtr: float64Ptr(55.77),
Numeric: 55.77,
RealPtr: float32Ptr(99.1),
Real: 99.1,
DoublePrecisionPtr: float64Ptr(11111111.22),
DoublePrecision: 11111111.22,
Smallserial: 1,
Serial: 1,
Bigserial: 1,
//MoneyPtr: nil,
//Money:
CharacterVaryingPtr: stringPtr("ABBA"),
CharacterVarying: "ABBA",
CharacterPtr: stringPtr("JOHN "),
Character: "JOHN ",
TextPtr: stringPtr("Some text"),
Text: "Some text",
ByteaPtr: []byte("bytea"),
Bytea: []byte("bytea"),
TimestampzPtr: timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
Timestampz: *timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
TimestampPtr: timestampWithoutTimeZone("1999-01-08 04:05:06", 0),
Timestamp: *timestampWithoutTimeZone("1999-01-08 04:05:06", 0),
DatePtr: timestampWithoutTimeZone("1999-01-08 00:00:00", 0),
Date: *timestampWithoutTimeZone("1999-01-08 00:00:00", 0),
TimezPtr: timeWithTimeZone("04:05:06 -0800"),
Timez: *timeWithTimeZone("04:05:06 -0800"),
TimePtr: timeWithoutTimeZone("04:05:06"),
Time: *timeWithoutTimeZone("04:05:06"),
IntervalPtr: stringPtr("3 days 04:05:06"),
Interval: "3 days 04:05:06",
BooleanPtr: boolPtr(true),
Boolean: false,
PointPtr: stringPtr("(2,3)"),
BitPtr: stringPtr("101"),
Bit: "101",
BitVaryingPtr: stringPtr("101111"),
BitVarying: "101111",
TsvectorPtr: stringPtr("'supernova':1"),
Tsvector: "'supernova':1",
UUIDPtr: uuidPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
XMLPtr: stringPtr("<Sub>abc</Sub>"),
XML: "<Sub>abc</Sub>",
JSONPtr: stringPtr(`{"a": 1, "b": 3}`),
JSON: `{"a": 1, "b": 3}`,
JsonbPtr: stringPtr(`{"a": 1, "b": 3}`),
Jsonb: `{"a": 1, "b": 3}`,
IntegerArrayPtr: stringPtr("{1,2,3}"),
IntegerArray: "{1,2,3}",
TextArrayPtr: stringPtr("{breakfast,consulting}"),
TextArray: "{breakfast,consulting}",
JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`,
TextMultiDimArrayPtr: stringPtr("{{meeting,lunch},{training,presentation}}"),
TextMultiDimArray: "{{meeting,lunch},{training,presentation}}",
}
var dest1 = model.AllTypes{
SmallintPtr: nil,
Smallint: 1,
IntegerPtr: nil,
Integer: 300,
BigintPtr: nil,
Bigint: 5000,
DecimalPtr: nil,
Decimal: 11.44,
NumericPtr: nil,
Numeric: 55.77,
RealPtr: nil,
Real: 99.1,
DoublePrecisionPtr: nil,
DoublePrecision: 11111111.22,
Smallserial: 2,
Serial: 2,
Bigserial: 2,
//MoneyPtr: nil,
//Money:
CharacterVaryingPtr: nil,
CharacterVarying: "ABBA",
CharacterPtr: nil,
Character: "JOHN ",
TextPtr: nil,
Text: "Some text",
ByteaPtr: nil,
Bytea: []byte("bytea"),
TimestampzPtr: nil,
Timestampz: *timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
TimestampPtr: nil,
Timestamp: *timestampWithoutTimeZone("1999-01-08 04:05:06", 0),
DatePtr: nil,
Date: *timestampWithoutTimeZone("1999-01-08 00:00:00", 0),
TimezPtr: nil,
Timez: *timeWithTimeZone("04:05:06 -0800"),
TimePtr: nil,
Time: *timeWithoutTimeZone("04:05:06"),
IntervalPtr: nil,
Interval: "3 days 04:05:06",
BooleanPtr: nil,
Boolean: false,
PointPtr: nil,
BitPtr: nil,
Bit: "101",
BitVaryingPtr: nil,
BitVarying: "101111",
TsvectorPtr: nil,
Tsvector: "'supernova':1",
UUIDPtr: nil,
UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
XMLPtr: nil,
XML: "<Sub>abc</Sub>",
JSONPtr: nil,
JSON: `{"a": 1, "b": 3}`,
JsonbPtr: nil,
Jsonb: `{"a": 1, "b": 3}`,
IntegerArrayPtr: nil,
IntegerArray: "{1,2,3}",
TextArrayPtr: nil,
TextArray: "{breakfast,consulting}",
JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`,
TextMultiDimArrayPtr: nil,
TextMultiDimArray: "{{meeting,lunch},{training,presentation}}",
}

View file

@ -1,81 +0,0 @@
package types
import (
"database/sql/driver"
"encoding/json"
"errors"
)
// JSONText is a json.RawMessage, which is a []byte underneath.
// Value() validates the json format in the source, and returns an error if
// the json is not valid. Scan does no validation. JSONText additionally
// implements `Unmarshal`, which unmarshals the json within to an interface{}
type JSONText json.RawMessage
var emptyJSON = JSONText("{}")
// MarshalJSON returns the *j as the JSON encoding of j.
func (j JSONText) MarshalJSON() ([]byte, error) {
if len(j) == 0 {
return emptyJSON, nil
}
return j, nil
}
// UnmarshalJSON sets *j to a copy of data
func (j *JSONText) UnmarshalJSON(data []byte) error {
if j == nil {
return errors.New("JSONText: UnmarshalJSON on nil pointer")
}
*j = append((*j)[0:0], data...)
return nil
}
// Value returns j as a value. This does a validating unmarshal into another
// RawMessage. If j is invalid json, it returns an error.
func (j JSONText) Value() (driver.Value, error) {
var m json.RawMessage
var err = j.Unmarshal(&m)
if err != nil {
return []byte{}, err
}
return []byte(j), nil
}
// Scan stores the src in *j. No validation is done.
func (j *JSONText) Scan(src interface{}) error {
if j == nil {
return errors.New("JSONText: Scan on nil pointer")
}
var source []byte
switch t := src.(type) {
case string:
source = []byte(t)
case []byte:
if len(t) == 0 {
source = emptyJSON
} else {
source = t
}
case nil:
*j = emptyJSON
default:
return errors.New("Incompatible type for JSONText")
}
*j = JSONText(append((*j)[0:0], source...))
return nil
}
// Unmarshal unmarshal's the json in j to v, as in json.Unmarshal.
func (j *JSONText) Unmarshal(v interface{}) error {
if len(*j) == 0 {
*j = emptyJSON
}
return json.Unmarshal([]byte(*j), v)
}
// String supports pretty printing for JSONText types.
func (j JSONText) String() string {
return string(j)
}