Tests clean up.

This commit is contained in:
go-jet 2019-06-11 12:47:35 +02:00
parent ffba8718ca
commit 367602757f
20 changed files with 46932 additions and 178 deletions

View file

@ -46,7 +46,7 @@ func main() {
SslMode: sslmode, SslMode: sslmode,
Params: params, Params: params,
DbName: dbName, DBName: dbName,
SchemaName: schemaName, SchemaName: schemaName,
} }

View file

@ -18,14 +18,14 @@ type GeneratorData struct {
SslMode string SslMode string
Params string Params string
DbName string DBName string
SchemaName string SchemaName string
} }
func Generate(destDir string, genData GeneratorData) error { func Generate(destDir string, genData GeneratorData) error {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
genData.Host, genData.Port, genData.User, genData.Password, genData.DbName, genData.SslMode, genData.Params) genData.Host, genData.Port, genData.User, genData.Password, genData.DBName, genData.SslMode, genData.Params)
db, err := sql.Open("postgres", connectionString) db, err := sql.Open("postgres", connectionString)
if err != nil { if err != nil {
@ -39,13 +39,13 @@ func Generate(destDir string, genData GeneratorData) error {
return err return err
} }
err = cleanUpGeneratedFiles(path.Join(destDir, genData.DbName, genData.SchemaName)) err = cleanUpGeneratedFiles(path.Join(destDir, genData.DBName, genData.SchemaName))
if err != nil { if err != nil {
return err return err
} }
schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DbName, genData.SchemaName) schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DBName, genData.SchemaName)
if err != nil { if err != nil {
return err return err

View file

@ -2,7 +2,10 @@ package sqlbuilder
import ( import (
"bytes" "bytes"
"github.com/google/uuid"
"strconv" "strconv"
"strings"
"time"
) )
type serializeOption int type serializeOption int
@ -175,6 +178,10 @@ func (q *queryData) reset() {
} }
func ArgToString(value interface{}) string { func ArgToString(value interface{}) string {
if isNil(value) {
return "NULL"
}
switch bindVal := value.(type) { switch bindVal := value.(type) {
case bool: case bool:
if bindVal { if bindVal {
@ -210,11 +217,18 @@ func ArgToString(value interface{}) string {
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string: case string:
return `'` + bindVal + `'` return stringQuote(bindVal)
case []byte: case []byte:
return `'` + string(bindVal) + `'` return stringQuote(string(bindVal))
//TODO: implement case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(bindVal.String())
default: default:
return "[Unknown type]" return "[Unknown type]"
} }
} }
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
}

View file

@ -57,7 +57,7 @@ func (c *columnImpl) defaultAlias() string {
return c.name return c.name
} }
func (c *columnImpl) serializeAsOrderBy(statement statementType, out *queryData) error { func (c *columnImpl) serializeForOrderBy(statement statementType, out *queryData) error {
if statement == set_statement { if statement == set_statement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
columnRef := "" columnRef := ""

View file

@ -141,7 +141,7 @@ func (e *expressionInterfaceImpl) serializeForProjection(statement statementType
return e.parent.serialize(statement, out, NO_WRAP) return e.parent.serialize(statement, out, NO_WRAP)
} }
func (e *expressionInterfaceImpl) serializeAsOrderBy(statement statementType, out *queryData) error { func (e *expressionInterfaceImpl) serializeForOrderBy(statement statementType, out *queryData) error {
return e.parent.serialize(statement, out, NO_WRAP) return e.parent.serialize(statement, out, NO_WRAP)
} }

View file

@ -7,6 +7,7 @@ import (
func TestExpressionIS_NULL(t *testing.T) { func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")
assertClauseSerializeErr(t, table2Col3.ADD(nil), "nil rhs.")
} }
func TestExpressionIS_NOT_NULL(t *testing.T) { func TestExpressionIS_NOT_NULL(t *testing.T) {

View file

@ -14,12 +14,12 @@ type InsertStatement interface {
// Add a row of values to the insert Statement. // Add a row of values to the insert Statement.
VALUES(values ...interface{}) InsertStatement VALUES(values ...interface{}) InsertStatement
// Map or stracture mapped to column names // Model structure mapped to column names
VALUES_MAPPING(data interface{}) InsertStatement MODEL(data interface{}) InsertStatement
RETURNING(projections ...projection) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...projection) InsertStatement
} }
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
@ -39,15 +39,14 @@ type insertStatementImpl struct {
errors []string errors []string
} }
func (s *insertStatementImpl) Query(db execution.Db, destination interface{}) error { func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(s, db, destination) return Query(i, db, destination)
} }
func (u *insertStatementImpl) Execute(db execution.Db) (res sql.Result, err error) { func (i *insertStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(u, db) return Execute(i, db)
} }
// Expression or default keyword
func (i *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { func (i *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
if len(values) == 0 { if len(values) == 0 {
return i return i
@ -67,20 +66,16 @@ func (i *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
return i return i
} }
func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement { func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
if data == nil { if data == nil {
i.addError("ADD method data is nil.") i.addError("MODEL : data is nil.")
return i return i
} }
value := reflect.ValueOf(data) value := reflect.Indirect(reflect.ValueOf(data))
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() != reflect.Struct { if value.Kind() != reflect.Struct {
i.addError("ADD method data is not struct or pointer to struct.") i.addError("MODEL : data is not struct or pointer to struct.")
return i return i
} }
@ -93,11 +88,21 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
structField := value.FieldByName(structFieldName) structField := value.FieldByName(structFieldName)
if !structField.IsValid() { if !structField.IsValid() {
i.addError("ADD() : Data structure doesn't contain field : " + structFieldName + " for column " + columnName) i.addError("MODEL : Data structure doesn't contain field for column " + columnName)
return i return i
} }
rowValues = append(rowValues, literal(structField.Interface())) var field interface{}
fieldValue := reflect.Indirect(structField)
if fieldValue.IsValid() {
field = fieldValue.Interface()
} else {
field = nil
}
rowValues = append(rowValues, literal(field))
} }
i.rows = append(i.rows, rowValues) i.rows = append(i.rows, rowValues)
@ -106,15 +111,12 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
} }
func (i *insertStatementImpl) RETURNING(projections ...projection) InsertStatement { func (i *insertStatementImpl) RETURNING(projections ...projection) InsertStatement {
//i.returning = defaultProjectionAliasing(projections)
i.returning = projections i.returning = projections
return i return i
} }
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement i.query = selectStatement
return i return i
} }
@ -126,9 +128,9 @@ func (i *insertStatementImpl) DebugSql() (query string, err error) {
return DebugSql(i) return DebugSql(i)
} }
func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
if len(s.errors) > 0 { if len(i.errors) > 0 {
return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", ")) return "", nil, errors.New("errors: " + strings.Join(i.errors, ", "))
} }
queryData := &queryData{} queryData := &queryData{}
@ -136,42 +138,40 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.nextLine() queryData.nextLine()
queryData.writeString("INSERT INTO") queryData.writeString("INSERT INTO")
if s.table == nil { if isNil(i.table) {
return "", nil, errors.New("nil tableName.") return "", nil, errors.New("table is nil")
} }
err = s.table.serialize(insert_statement, queryData) err = i.table.serialize(insert_statement, queryData)
queryData.writeByte(' ')
if err != nil { if err != nil {
return "", nil, err return
} }
if len(s.columns) > 0 { if len(i.columns) > 0 {
queryData.writeString("(") queryData.writeString("(")
err = serializeColumnList(insert_statement, s.columns, queryData) err = serializeColumnList(insert_statement, i.columns, queryData)
if err != nil { if err != nil {
return "", nil, err return
} }
queryData.writeString(")") queryData.writeString(")")
} }
if len(s.rows) == 0 && s.query == nil { if len(i.rows) == 0 && i.query == nil {
return "", nil, errors.New("No row or query specified.") return "", nil, errors.New("no row values or query specified")
} }
if len(s.rows) > 0 && s.query != nil { if len(i.rows) > 0 && i.query != nil {
return "", nil, errors.New("Only new rows or query has to be specified.") return "", nil, errors.New("only row values or query has to be specified")
} }
if len(s.rows) > 0 { if len(i.rows) > 0 {
queryData.writeString("VALUES") queryData.writeString("VALUES")
for row_i, row := range s.rows { for row_i, row := range i.rows {
if row_i > 0 { if row_i > 0 {
queryData.writeString(",") queryData.writeString(",")
} }
@ -180,8 +180,8 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.nextLine() queryData.nextLine()
queryData.writeString("(") queryData.writeString("(")
if len(row) != len(s.columns) { if len(row) != len(i.columns) {
return "", nil, errors.New("# of values does not match # of columns.") return "", nil, errors.New("number of values does not match number of columns")
} }
err = serializeClauseList(insert_statement, row, queryData) err = serializeClauseList(insert_statement, row, queryData)
@ -195,19 +195,19 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
} }
if s.query != nil { if i.query != nil {
err = s.query.serialize(insert_statement, queryData) err = i.query.serialize(insert_statement, queryData)
if err != nil { if err != nil {
return return
} }
} }
if len(s.returning) > 0 { if len(i.returning) > 0 {
queryData.nextLine() queryData.nextLine()
queryData.writeString("RETURNING") queryData.writeString("RETURNING")
err = queryData.writeProjections(insert_statement, s.returning) err = queryData.writeProjections(insert_statement, i.returning)
if err != nil { if err != nil {
return return

View file

@ -105,28 +105,28 @@ INSERT INTO db.table1 (col1,colFloat) VALUES
func TestInsertValuesFromModel(t *testing.T) { func TestInsertValuesFromModel(t *testing.T) {
type Table1Model struct { type Table1Model struct {
Col1 int Col1 *int
ColFloat float64 ColFloat float64
} }
one := 1
toInsert := Table1Model{ toInsert := Table1Model{
Col1: 1, Col1: &one,
ColFloat: 1.11, ColFloat: 1.11,
} }
stmt := table1.INSERT(table1Col1, table1ColFloat). stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES_MAPPING(toInsert) MODEL(toInsert).
MODEL(&toInsert)
sql, _, err := stmt.Sql() expectedSql := `
assert.NilError(t, err)
fmt.Println(sql)
assert.Equal(t, sql, `
INSERT INTO db.table1 (col1,colFloat) VALUES INSERT INTO db.table1 (col1,colFloat) VALUES
($1, $2); ($1, $2),
`) ($3, $4);
`
assertQuery(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11))
} }
func TestInsertValuesFromModelColumnMismatch(t *testing.T) { func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -141,11 +141,11 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
} }
stmt := table1.INSERT(table1Col1, table1ColFloat). stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES_MAPPING(toInsert) MODEL(toInsert)
_, _, err := stmt.Sql() _, _, err := stmt.Sql()
//fmt.Println(err) fmt.Println(err)
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }

View file

@ -3,7 +3,7 @@ package sqlbuilder
import "errors" import "errors"
type OrderByClause interface { type OrderByClause interface {
serializeAsOrderBy(statement statementType, out *queryData) error serializeForOrderBy(statement statementType, out *queryData) error
} }
type orderByClauseImpl struct { type orderByClauseImpl struct {
@ -11,12 +11,12 @@ type orderByClauseImpl struct {
ascent bool ascent bool
} }
func (o *orderByClauseImpl) serializeAsOrderBy(statement statementType, out *queryData) error { func (o *orderByClauseImpl) serializeForOrderBy(statement statementType, out *queryData) error {
if o.expression == nil { if o.expression == nil {
return errors.New("nil orderBy by clause.") return errors.New("nil orderBy by clause.")
} }
if err := o.expression.serializeAsOrderBy(statement, out); err != nil { if err := o.expression.serializeForOrderBy(statement, out); err != nil {
return err return err
} }

View file

@ -57,7 +57,7 @@ type setStatementImpl struct {
selects []rowsType selects []rowsType
orderBy []OrderByClause orderBy []OrderByClause
limit, offset int64 limit, offset int64
// True if results of the union should be deduped.
all bool all bool
} }
@ -76,7 +76,6 @@ func newSetStatementImpl(operator string, all bool, selects ...rowsType) SetStat
} }
func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
us.orderBy = orderBy us.orderBy = orderBy
return us return us
} }
@ -100,7 +99,9 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData, op
return errors.New("Set expression is nil. ") return errors.New("Set expression is nil. ")
} }
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { wrap := s.orderBy != nil || s.limit >= 0 || s.offset >= 0
if wrap {
out.writeString("(") out.writeString("(")
out.increaseIdent() out.increaseIdent()
} }
@ -111,7 +112,7 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData, op
return err return err
} }
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { if wrap {
out.decreaseIdent() out.decreaseIdent()
out.nextLine() out.nextLine()
out.writeString(")") out.writeString(")")

View file

@ -162,3 +162,123 @@ func TestUnionInUnion(t *testing.T) {
assert.Equal(t, len(args), 0) assert.Equal(t, len(args), 0)
assert.Equal(t, queryStr, expectedSql) assert.Equal(t, queryStr, expectedSql)
} }
func TestUnionALL(t *testing.T) {
query, args, err := UNION_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestINTERSECT(t *testing.T) {
query, args, err := INTERSECT(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
INTERSECT
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestINTERSECT_ALL(t *testing.T) {
query, args, err := INTERSECT_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
INTERSECT ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestEXCEPT(t *testing.T) {
query, args, err := EXCEPT(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
EXCEPT
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}
func TestEXCEPT_ALL(t *testing.T) {
query, args, err := EXCEPT_ALL(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
EXCEPT ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
}

View file

@ -80,3 +80,10 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string
assert.DeepEqual(t, out.buff.String(), query) assert.DeepEqual(t, out.buff.String(), query)
assert.DeepEqual(t, out.args, args) assert.DeepEqual(t, out.args, args)
} }
func assertQuery(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, queryStr, expectedQuery)
assert.DeepEqual(t, args, expectedArgs)
}

View file

@ -14,7 +14,7 @@ func serializeOrderByClauseList(statement statementType, orderByClauses []OrderB
out.writeString(", ") out.writeString(", ")
} }
err := value.serializeAsOrderBy(statement, out) err := value.serializeForOrderBy(statement, out)
if err != nil { if err != nil {
return err return err

View file

@ -0,0 +1,13 @@
package dbconfig
import "fmt"
const (
Host = "localhost"
Port = 5432
User = "postgres"
Password = "postgres"
DBName = "dvd_rental"
)
var ConnectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName)

46553
tests/init/data/dvds.sql Normal file

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,6 @@
-- AllTypes table -----------------------------
DROP TABLE IF EXISTS test_sample.all_types; DROP TABLE IF EXISTS test_sample.all_types;
CREATE TABLE test_sample.all_types CREATE TABLE test_sample.all_types
@ -135,4 +137,55 @@ VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 111
NULL, '{1, 2, 3}', NULL, '{"breakfast", "consulting"}', ARRAY['{"a": 1, "b": 2}'::jsonb, '{"a":3, "b": 4}'::jsonb], NULL, '{{"meeting", "lunch"}, {"training", "presentation"}}') NULL, '{1, 2, 3}', NULL, '{"breakfast", "consulting"}', ARRAY['{"a": 1, "b": 2}'::jsonb, '{"a":3, "b": 4}'::jsonb], NULL, '{{"meeting", "lunch"}, {"training", "presentation"}}')
; ;
-- Link table --------------------
DROP TABLE IF EXISTS test_sample.link;
CREATE TABLE IF NOT EXISTS test_sample.link (
ID serial PRIMARY KEY,
url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL,
description VARCHAR (255),
rel VARCHAR (50)
);
-- Employee table ---------------
DROP TABLE IF EXISTS test_sample.employee;
CREATE TABLE test_sample.employee (
employee_id INT PRIMARY KEY,
first_name VARCHAR (255) NOT NULL,
last_name VARCHAR (255) NOT NULL,
manager_id INT,
FOREIGN KEY (manager_id)
REFERENCES test_sample.employee (employee_id)
ON DELETE CASCADE
);
INSERT INTO test_sample.employee (
employee_id,
first_name,
last_name,
manager_id
)
VALUES
(1, 'Windy', 'Hays', NULL),
(2, 'Ava', 'Christensen', 1),
(3, 'Hassan', 'Conner', 1),
(4, 'Anna', 'Reeves', 2),
(5, 'Sau', 'Norman', 2),
(6, 'Kelsie', 'Hays', 3),
(7, 'Tory', 'Goff', 3),
(8, 'Salley', 'Lester', 3);
-- Person table ------------------
DROP TABLE IF EXISTS test_sample.person;
CREATE TABLE test_sample.person(
person_id uuid,
first_name varchar(100),
last_name varchar(100)
)

60
tests/init/init.go Normal file
View file

@ -0,0 +1,60 @@
package main
import (
"database/sql"
"github.com/go-jet/jet/generator"
"github.com/go-jet/jet/tests/dbconfig"
"io/ioutil"
)
func main() {
db, err := sql.Open("postgres", dbconfig.ConnectString)
if err != nil {
panic("Failed to connect to test db")
}
defer db.Close()
testSampleSql, err := ioutil.ReadFile("./init/data/test_sample.sql")
panicOnError(err)
_, err = db.Exec(string(testSampleSql))
panicOnError(err)
dvdsSql, err := ioutil.ReadFile("./init/data/dvds.sql")
panicOnError(err)
_, err = db.Exec(string(dvdsSql))
panicOnError(err)
err = generator.Generate("./.test_files", generator.GeneratorData{
Host: dbconfig.Host,
Port: "5432",
User: dbconfig.User,
Password: dbconfig.Password,
DBName: dbconfig.DBName,
SchemaName: "dvds",
})
panicOnError(err)
err = generator.Generate("./.test_files", generator.GeneratorData{
Host: dbconfig.Host,
Port: "5432",
User: dbconfig.User,
Password: dbconfig.Password,
DBName: dbconfig.DBName,
SchemaName: "test_sample",
})
panicOnError(err)
}
func panicOnError(err error) {
if err != nil {
panic(err)
}
}

View file

@ -3,20 +3,20 @@ package tests
import ( import (
"fmt" "fmt"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-jet/jet/sqlbuilder" . "github.com/go-jet/jet/sqlbuilder"
"github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model" "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model"
"github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
) )
func TestInsertValues(t *testing.T) { func TestInsertValues(t *testing.T) {
insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel). insertQuery := Link.INSERT(Link.URL, Link.Name, Link.Rel).
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT). VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES("http://www.google.com", "Google", sqlbuilder.DEFAULT). VALUES("http://www.google.com", "Google", DEFAULT).
VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT). VALUES("http://www.yahoo.com", "Yahoo", DEFAULT).
VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT). VALUES("http://www.bing.com", "Bing", DEFAULT).
RETURNING(table.Link.ID) RETURNING(Link.ID)
insertQueryStr, args, err := insertQuery.Sql() insertQueryStr, args, err := insertQuery.Sql()
@ -44,7 +44,7 @@ RETURNING link.id AS "link.id";
link := []model.Link{} link := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns).Query(db, &link) err = Link.SELECT(Link.AllColumns).Query(db, &link)
assert.NilError(t, err) assert.NilError(t, err)
@ -72,9 +72,9 @@ func TestInsertDataObject(t *testing.T) {
Rel: nil, Rel: nil,
} }
query := table.Link. query := Link.
INSERT(table.Link.URL, table.Link.Name). INSERT(Link.URL, Link.Name).
VALUES_MAPPING(linkData) MODEL(linkData)
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql()
@ -92,14 +92,14 @@ func TestInsertDataObject(t *testing.T) {
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
_, err := table.Link.INSERT(table.Link.URL, table.Link.Name). _, err := Link.INSERT(Link.URL, Link.Name).
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial").Execute(db) VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial").Execute(db)
assert.NilError(t, err) assert.NilError(t, err)
query := table.Link. query := Link.
INSERT(table.Link.URL, table.Link.Name). INSERT(Link.URL, Link.Name).
QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name)) QUERY(Link.SELECT(Link.URL, Link.Name))
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql()
@ -113,7 +113,7 @@ func TestInsertQuery(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
allLinks := []model.Link{} allLinks := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns).Query(db, &allLinks) err = Link.SELECT(Link.AllColumns).Query(db, &allLinks)
assert.NilError(t, err) assert.NilError(t, err)
spew.Dump(allLinks) spew.Dump(allLinks)

View file

@ -2,8 +2,8 @@ package tests
import ( import (
"database/sql" "database/sql"
"fmt"
"github.com/go-jet/jet/tests/.test_files/dvd_rental/dvds/model" "github.com/go-jet/jet/tests/.test_files/dvd_rental/dvds/model"
"github.com/go-jet/jet/tests/dbconfig"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/pkg/profile" "github.com/pkg/profile"
"gotest.tools/assert" "gotest.tools/assert"
@ -12,103 +12,23 @@ import (
"testing" "testing"
) )
const (
host = "localhost"
port = 5432
user = "postgres"
password = "postgres"
dbname = "dvd_rental"
)
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
var db *sql.DB var db *sql.DB
//var tx *sql.Tx
//go:generate generator -host=localhost -port=5432 -user=postgres -password=postgres -dbname=dvd_rental -schema dvds -path .test_files
//go:generate generator -host=localhost -port=5432 -user=postgres -password=postgres -dbname=dvd_rental -sslmode=disable -schema test_sample -path .test_files
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
fmt.Println("Begin")
defer profile.Start().Stop() defer profile.Start().Stop()
var err error var err error
db, err = sql.Open("postgres", connectString) db, err = sql.Open("postgres", dbconfig.ConnectString)
if err != nil { if err != nil {
panic("Failed to connect to test db") panic("Failed to connect to test db")
} }
//tx, _ = db.Begin() defer db.Close()
defer cleanUp()
dbInit()
ret := m.Run() ret := m.Run()
cleanUp()
fmt.Println("END")
os.Exit(ret) os.Exit(ret)
} }
func cleanUp() {
fmt.Println("CLEAN UP")
//tx.Rollback()
db.Close()
}
func dbInit() {
linkTableCreate := `
DROP TABLE IF EXISTS test_sample.link;
CREATE TABLE IF NOT EXISTS test_sample.link (
ID serial PRIMARY KEY,
url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL,
description VARCHAR (255),
rel VARCHAR (50)
);
DROP TABLE IF EXISTS test_sample.employee;
CREATE TABLE test_sample.employee (
employee_id INT PRIMARY KEY,
first_name VARCHAR (255) NOT NULL,
last_name VARCHAR (255) NOT NULL,
manager_id INT,
FOREIGN KEY (manager_id)
REFERENCES test_sample.employee (employee_id)
ON DELETE CASCADE
);
INSERT INTO test_sample.employee (
employee_id,
first_name,
last_name,
manager_id
)
VALUES
(1, 'Windy', 'Hays', NULL),
(2, 'Ava', 'Christensen', 1),
(3, 'Hassan', 'Conner', 1),
(4, 'Anna', 'Reeves', 2),
(5, 'Sau', 'Norman', 2),
(6, 'Kelsie', 'Hays', 3),
(7, 'Tory', 'Goff', 3),
(8, 'Salley', 'Lester', 3);
`
result, err := db.Exec(linkTableCreate)
if err != nil {
panic(err)
}
fmt.Println(result)
}
func TestGenerateModel(t *testing.T) { func TestGenerateModel(t *testing.T) {
actor := model.Actor{} actor := model.Actor{}

View file

@ -25,6 +25,18 @@ func TestAllTypesSelect(t *testing.T) {
assert.DeepEqual(t, dest[1], allTypesRow1) assert.DeepEqual(t, dest[1], allTypesRow1)
} }
func TestAllTypesInsert(t *testing.T) {
query := AllTypes.INSERT(AllTypes.AllColumns...).
MODEL(allTypesRow0).
MODEL(&allTypesRow1)
_, err := query.Execute(db)
assert.NilError(t, err)
fmt.Println(query.DebugSql())
}
func TestExpressionOperators(t *testing.T) { func TestExpressionOperators(t *testing.T) {
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.Integer.IS_NULL(), AllTypes.Integer.IS_NULL(),