Add support for INSERT statements.

This commit is contained in:
zer0sub 2019-04-07 09:58:12 +02:00
parent d84deb8745
commit 599a8c537a
15 changed files with 586 additions and 277 deletions

View file

@ -19,7 +19,7 @@ func TestBinaryExpression(t *testing.T) {
alias := boolExpression.As("alias_eq_expression")
out := bytes.Buffer{}
err := alias.SerializeSql(&out)
err := alias.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `2 = 3 AS "alias_eq_expression"`)
@ -59,7 +59,7 @@ func TestUnaryExpression(t *testing.T) {
alias := notExpression.As("alias_not_expression")
out := bytes.Buffer{}
err := alias.SerializeSql(&out)
err := alias.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), ` NOT 2 = 1 AS "alias_not_expression"`)

View file

@ -31,7 +31,7 @@ func TestNewBoolColumn(t *testing.T) {
err = boolColumn.setTableName("table1")
assert.NilError(t, err)
aliasedBoolColumn := boolColumn.As("alias1")
err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION)
err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
}
@ -61,7 +61,7 @@ func TestNewIntColumn(t *testing.T) {
err = integerColumn.setTableName("table1")
assert.NilError(t, err)
aliasedBoolColumn := integerColumn.As("alias1")
err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION)
err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
}
@ -76,7 +76,7 @@ func TestNewNumericColumnColumn(t *testing.T) {
assert.Equal(t, out.String(), "col")
out.Reset()
err = numericColumn.SerializeSql(&out, FOR_PROJECTION)
err = numericColumn.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
@ -91,7 +91,7 @@ func TestNewNumericColumnColumn(t *testing.T) {
err = numericColumn.setTableName("table1")
assert.NilError(t, err)
aliasedBoolColumn := numericColumn.As("alias1")
err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION)
err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
}

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"github.com/serenize/snaker"
"github.com/sub0Zero/go-sqlbuilder/types"
"reflect"
"regexp"
"strconv"
@ -13,7 +14,7 @@ import (
"time"
)
func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
func Execute(db types.Db, query string, destinationPtr interface{}) error {
if db == nil {
return errors.New("db is nil")
}

View file

@ -0,0 +1,225 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/serenize/snaker"
"github.com/sub0Zero/go-sqlbuilder/types"
"reflect"
"strings"
)
type InsertStatement interface {
Statement
// Add a row of values to the insert statement.
VALUES(values ...interface{}) InsertStatement
// Map or stracture mapped to column names
VALUES_MAPPING(data interface{}) InsertStatement
RETURNING(column ...Expression) InsertStatement
Execute(db types.Db) (sql.Result, error)
}
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
rows: make([][]Expression, 0, 1),
returning: make([]Expression, 0, 1),
}
}
type columnAssignment struct {
col Column
expr Expression
}
type insertStatementImpl struct {
table WritableTable
columns []Column
rows [][]Expression
returning []Expression
errors []string
}
func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
query, err := i.String()
if err != nil {
return
}
res, err = db.Exec(query)
return
}
//func (i *insertStatementImpl) ExecuteInTx(tx *sql.Tx) (res sql.Result, err error) {
// query, err := i.String()
//
// if err != nil {
// return
// }
//
// res, err = tx.Exec(query)
//
// return
//}
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
literalRow := []Expression{}
for _, value := range values {
literalRow = append(literalRow, Literal(value))
}
s.rows = append(s.rows, literalRow)
return s
}
func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
if data == nil {
i.addError("Add method data is nil.")
return i
}
value := reflect.ValueOf(data)
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() != reflect.Struct {
i.addError("Add method data is not struct or pointer to struct.")
return i
}
rowValues := []Expression{}
for _, column := range i.columns {
columnName := column.Name()
structFieldName := snaker.SnakeToCamel(columnName)
structField := value.FieldByName(structFieldName)
if !structField.IsValid() {
i.addError("Add() : Data structure doesn't contain field : " + structFieldName + " for column " + columnName)
return i
}
rowValues = append(rowValues, Literal(structField.Interface()))
}
i.rows = append(i.rows, rowValues)
return i
}
func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement {
i.returning = column
return i
}
func (i *insertStatementImpl) addError(err string) {
i.errors = append(i.errors, err)
}
func (s *insertStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("INSERT ")
_, _ = buf.WriteString("INTO ")
if len(s.errors) > 0 {
return "", errors.New("sql builder errors: " + strings.Join(s.errors, ", "))
}
if s.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
buf.WriteString(s.table.SchemaName() + "." + s.table.TableName())
if len(s.columns) == 0 {
return "", errors.Newf(
"No column specified. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" (")
for i, col := range s.columns {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column in columns list. Generated sql: %s",
buf.String())
}
buf.WriteString(col.Name())
}
if len(s.rows) == 0 {
return "", errors.Newf(
"No row specified. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(") VALUES (")
for row_i, row := range s.rows {
if row_i > 0 {
_, _ = buf.WriteString(", (")
}
if len(row) != len(s.columns) {
return "", errors.Newf(
"# of values does not match # of columns. Generated sql: %s",
buf.String())
}
for col_i, value := range row {
if col_i > 0 {
_ = buf.WriteByte(',')
}
if value == nil {
return "", errors.Newf(
"nil value in row %d col %d. Generated sql: %s",
row_i,
col_i,
buf.String())
}
if err = value.SerializeSql(buf); err != nil {
return
}
}
_ = buf.WriteByte(')')
}
if len(s.returning) > 0 {
buf.WriteString(" RETURNING ")
for i, column := range s.returning {
if i > 0 {
buf.WriteString(",")
}
err = column.SerializeSql(buf)
if err != nil {
return
}
}
}
buf.WriteByte(';')
return buf.String(), nil
}

View file

@ -0,0 +1,125 @@
package sqlbuilder
import (
"fmt"
"gotest.tools/assert"
"testing"
"time"
)
func TestInsertNoColumn(t *testing.T) {
_, err := table1.INSERT().VALUES().String()
assert.Assert(t, err != nil)
}
func TestInsertNoRow(t *testing.T) {
_, err := table1.INSERT(table1Col1).String()
assert.Assert(t, err != nil)
}
func TestInsertColumnLengthMismatch(t *testing.T) {
_, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).String()
fmt.Println(err)
assert.Assert(t, err != nil)
}
func TestInsertNilValue(t *testing.T) {
_, err := table1.INSERT(table1Col1).VALUES(nil).String()
assert.Assert(t, err != nil)
}
func TestInsertNilColumn(t *testing.T) {
_, err := table1.INSERT(nil).VALUES(1).String()
assert.Assert(t, err != nil)
}
func TestInsertSingleValue(t *testing.T) {
sql, err := table1.INSERT(table1Col1).VALUES(1).String()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES (1)")
}
func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
sql, err := table1.INSERT(table1Col4).VALUES(date).String()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 (col4) "+
"VALUES ('1999-01-02 03:04:05.000000')")
}
func TestInsertMultipleValues(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2, table1Col3)
stmt.VALUES(1, 2, 3)
sql, err := stmt.String()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 "+
"(col1,col2,col3) "+
"VALUES (1,2,3)")
}
func TestInsertMultipleRows(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES(1, 2).
VALUES(11, 22).
VALUES(111, 222)
sql, err := stmt.String()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 "+
"(col1,col2) "+
"VALUES (1,2), (11,22), (111,222)")
}
func TestInsertValuesFromModel(t *testing.T) {
type Table1Model struct {
Col1 int
Col2 string
}
toInsert := Table1Model{
Col1: 1,
Col2: "one",
}
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES_MAPPING(toInsert)
sql, err := stmt.String()
assert.NilError(t, err)
fmt.Println(sql)
assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES (1,'one')`)
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
type Table1Model struct {
Col1Prim int
Col2 string
}
toInsert := Table1Model{
Col1Prim: 1,
Col2: "one",
}
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES_MAPPING(toInsert)
_, err := stmt.String()
fmt.Println(err)
assert.Assert(t, err != nil)
}

View file

@ -2,10 +2,10 @@ package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0Zero/go-sqlbuilder/types"
"reflect"
)
@ -28,6 +28,9 @@ type SelectStatement interface {
Copy() SelectStatement
AsTable(alias string) *SelectStatementTable
Execute(db types.Db, destination interface{}) error
//ExecuteInTx(tx *sql.Tx, destination interface{}) error
}
// NOTE: SelectStatement purposely does not implement the Table interface since
@ -84,7 +87,7 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
}
}
func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error {
func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error {
destinationType := reflect.TypeOf(destination)
if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct {

View file

@ -8,6 +8,15 @@ type SelectStatementTable struct {
alias string
}
// Returns the tableName's name in the database
func (t *SelectStatementTable) SchemaName() string {
return ""
}
func (s *SelectStatementTable) TableName() string {
return s.alias
}
func (s *SelectStatementTable) Columns() []Column {
return s.columns
}

View file

@ -12,17 +12,6 @@ import (
type Statement interface {
// String returns generated SQL as string.
String() (sql string, err error)
Execute(db *sql.DB, destination interface{}) error
}
type InsertStatement interface {
Statement
// Add a row of values to the insert statement.
Add(row ...Expression) InsertStatement
AddOnDuplicateKeyUpdate(col Column, expr Expression) InsertStatement
Comment(comment string) InsertStatement
IgnoreDuplicates(ignore bool) InsertStatement
}
// By default, rows selected by a UNION statement are out-of-order
@ -261,183 +250,6 @@ func (us *unionStatementImpl) String() (sql string, err error) {
return buf.String(), nil
}
//
// INSERT Statement ============================================================
//
func newInsertStatement(
t WritableTable,
columns ...Column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
rows: make([][]Expression, 0, 1),
onDuplicateKeyUpdates: make([]columnAssignment, 0, 0),
}
}
type columnAssignment struct {
col Column
expr Expression
}
type insertStatementImpl struct {
table WritableTable
columns []Column
rows [][]Expression
onDuplicateKeyUpdates []columnAssignment
comment string
ignore bool
}
func (i *insertStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
func (s *insertStatementImpl) Add(
row ...Expression) InsertStatement {
s.rows = append(s.rows, row)
return s
}
func (s *insertStatementImpl) AddOnDuplicateKeyUpdate(
col Column,
expr Expression) InsertStatement {
s.onDuplicateKeyUpdates = append(
s.onDuplicateKeyUpdates,
columnAssignment{col, expr})
return s
}
func (s *insertStatementImpl) IgnoreDuplicates(ignore bool) InsertStatement {
s.ignore = ignore
return s
}
func (s *insertStatementImpl) Comment(comment string) InsertStatement {
s.comment = comment
return s
}
func (s *insertStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("INSERT ")
if s.ignore {
_, _ = buf.WriteString("IGNORE ")
}
_, _ = buf.WriteString("INTO ")
if err = writeComment(s.comment, buf); err != nil {
return
}
if s.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = s.table.SerializeSql(buf); err != nil {
return
}
if len(s.columns) == 0 {
return "", errors.Newf(
"No column specified. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" (")
for i, col := range s.columns {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column in columns list. Generated sql: %s",
buf.String())
}
if err = col.SerializeSql(buf, FOR_PROJECTION); err != nil {
return
}
}
if len(s.rows) == 0 {
return "", errors.Newf(
"No row specified. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(") VALUES (")
for row_i, row := range s.rows {
if row_i > 0 {
_, _ = buf.WriteString(", (")
}
if len(row) != len(s.columns) {
return "", errors.Newf(
"# of values does not match # of columns. Generated sql: %s",
buf.String())
}
for col_i, value := range row {
if col_i > 0 {
_ = buf.WriteByte(',')
}
if value == nil {
return "", errors.Newf(
"nil value in row %d col %d. Generated sql: %s",
row_i,
col_i,
buf.String())
}
if err = value.SerializeSql(buf); err != nil {
return
}
}
_ = buf.WriteByte(')')
}
if len(s.onDuplicateKeyUpdates) > 0 {
_, _ = buf.WriteString(" ON DUPLICATE KEY UPDATE ")
for i, colExpr := range s.onDuplicateKeyUpdates {
if i > 0 {
_, _ = buf.WriteString(", ")
}
if colExpr.col == nil {
return "", errors.Newf(
"nil column in on duplicate key update list. "+"Generated sql: %s",
buf.String())
}
if err = colExpr.col.SerializeSql(buf, FOR_PROJECTION); err != nil {
return
}
_ = buf.WriteByte('=')
if colExpr.expr == nil {
return "", errors.Newf(
"nil expression in on duplicate key update list. "+"Generated sql: %s",
buf.String())
}
if err = colExpr.expr.SerializeSql(buf); err != nil {
return
}
}
}
return buf.String(), nil
}
//
// UPDATE statement ===========================================================
//

View file

@ -237,37 +237,37 @@ func (s *StmtSuite) TestSelectDistinct(c *gc.C) {
//
func (s *StmtSuite) TestInsertNoColumn(c *gc.C) {
_, err := table1.Insert().Add().String()
_, err := table1.INSERT().Add().String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestInsertNoRow(c *gc.C) {
_, err := table1.Insert(table1Col1).String()
_, err := table1.INSERT(table1Col1).String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestInsertColumnLengthMismatch(c *gc.C) {
_, err := table1.Insert(table1Col1, table1Col2).Add(nil).String()
_, err := table1.INSERT(table1Col1, table1Col2).Add(nil).String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestInsertNilValue(c *gc.C) {
_, err := table1.Insert(table1Col1).Add(nil).String()
_, err := table1.INSERT(table1Col1).Add(nil).String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestInsertNilColumn(c *gc.C) {
_, err := table1.Insert(nil).Add(Literal(1)).String()
_, err := table1.INSERT(nil).Add(Literal(1)).String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestInsertSingleValue(c *gc.C) {
sql, err := table1.Insert(table1Col1).Add(Literal(1)).String()
sql, err := table1.INSERT(table1Col1).Add(Literal(1)).String()
c.Assert(err, gc.IsNil)
c.Assert(
@ -279,7 +279,7 @@ func (s *StmtSuite) TestInsertSingleValue(c *gc.C) {
func (s *StmtSuite) TestInsertDate(c *gc.C) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
sql, err := table1.Insert(table1Col4).Add(Literal(date)).String()
sql, err := table1.INSERT(table1Col4).Add(Literal(date)).String()
c.Assert(err, gc.IsNil)
c.Assert(
@ -290,7 +290,7 @@ func (s *StmtSuite) TestInsertDate(c *gc.C) {
}
func (s *StmtSuite) TestInsertIgnore(c *gc.C) {
stmt := table1.Insert(table1Col1).Add(Literal(1)).IgnoreDuplicates(true)
stmt := table1.INSERT(table1Col1).Add(Literal(1)).IgnoreDuplicates(true)
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
@ -301,7 +301,7 @@ func (s *StmtSuite) TestInsertIgnore(c *gc.C) {
}
func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) {
stmt := table1.Insert(table1Col1, table1Col2, table1Col3)
stmt := table1.INSERT(table1Col1, table1Col2, table1Col3)
stmt.Add(Literal(1), Literal(2), Literal(3))
sql, err := stmt.String()
@ -316,7 +316,7 @@ func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) {
}
func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) {
stmt := table1.Insert(table1Col1, table1Col2)
stmt := table1.INSERT(table1Col1, table1Col2)
stmt.Add(Literal(1), Literal(2))
stmt.Add(Literal(11), Literal(22))
stmt.Add(Literal(111), Literal(222))
@ -333,7 +333,7 @@ func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) {
}
func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) {
stmt := table1.Insert(table1Col1, table1Col2)
stmt := table1.INSERT(table1Col1, table1Col2)
stmt.Add(Literal(1), Literal(2))
stmt.AddOnDuplicateKeyUpdate(nil, Literal(3))
@ -342,7 +342,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) {
}
func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) {
stmt := table1.Insert(table1Col1, table1Col2)
stmt := table1.INSERT(table1Col1, table1Col2)
stmt.Add(Literal(1), Literal(2))
stmt.AddOnDuplicateKeyUpdate(table1Col1, nil)
@ -351,7 +351,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) {
}
func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) {
stmt := table1.Insert(table1Col1, table1Col2)
stmt := table1.INSERT(table1Col1, table1Col2)
stmt.Add(Literal(1), Literal(2))
stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3))
@ -368,7 +368,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) {
}
func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) {
stmt := table1.Insert(table1Col1, table1Col2)
stmt := table1.INSERT(table1Col1, table1Col2)
stmt.Add(Literal(1), Literal(2))
stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3))
stmt.AddOnDuplicateKeyUpdate(table1Col2, Literal(4))

View file

@ -8,17 +8,20 @@ import (
"github.com/dropbox/godropbox/errors"
)
// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
// are not supported.
type ReadableTable interface {
type TableInterface interface {
SchemaName() string
TableName() string
// Returns the list of columns that are in the current tableName expression.
Columns() []Column
//Column(name string) Column
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
SerializeSql(out *bytes.Buffer) error
}
// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
// are not supported.
type ReadableTable interface {
TableInterface
// Generates a select query on the current tableName.
SELECT(projections ...Projection) SelectStatement
@ -26,8 +29,6 @@ type ReadableTable interface {
// Creates a inner join tableName expression using onCondition.
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
//InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable
// Creates a left join tableName expression using onCondition.
LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
@ -41,15 +42,9 @@ type ReadableTable interface {
// The sql tableName write interface.
type WritableTable interface {
// Returns the list of columns that are in the tableName.
Columns() []Column
TableInterface
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
// The database is the name of the database the tableName is on
SerializeSql(out *bytes.Buffer) error
Insert(columns ...Column) InsertStatement
INSERT(columns ...Column) InsertStatement
Update() UpdateStatement
Delete() DeleteStatement
}
@ -72,7 +67,7 @@ func NewTable(schemaName, name string, columns ...Column) *Table {
if err != nil {
panic(err)
}
t.columnLookup[c.Name()] = c
t.columnLookup[c.TableName()] = c
}
if len(columns) == 0 {
@ -132,11 +127,16 @@ func (t *Table) SetAlias(alias string) {
}
// Returns the tableName's name in the database
func (t *Table) Name() string {
func (t *Table) SchemaName() string {
return t.schemaName
}
// Returns the tableName's name in the database
func (t *Table) TableName() string {
return t.name
}
func (t *Table) SchemaName() string {
func (t *Table) SchemaTableName() string {
return t.schemaName
}
@ -161,7 +161,7 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error {
_, _ = out.WriteString(t.schemaName)
_, _ = out.WriteString(".")
_, _ = out.WriteString(t.Name())
_, _ = out.WriteString(t.TableName())
if len(t.alias) > 0 {
out.WriteString(" AS ")
@ -225,7 +225,7 @@ func (t *Table) CrossJoin(table ReadableTable) ReadableTable {
return CrossJoin(t, table)
}
func (t *Table) Insert(columns ...Column) InsertStatement {
func (t *Table) INSERT(columns ...Column) InsertStatement {
return newInsertStatement(t, columns...)
}
@ -308,6 +308,15 @@ func CrossJoin(
return newJoinTable(lhs, rhs, CROSS_JOIN, nil)
}
// Returns the tableName's name in the database
func (t *joinTable) SchemaName() string {
return ""
}
func (t *joinTable) TableName() string {
return ""
}
func (t *joinTable) Columns() []Column {
columns := make([]Column, 0)
columns = append(columns, t.lhs.Columns()...)

View file

@ -1,11 +1,9 @@
// +build disabled
package sqlbuilder
var table1Col1 = IntColumn("col1", Nullable)
var table1Col2 = IntColumn("col2", Nullable)
var table1Col3 = IntColumn("col3", Nullable)
var table1Col4 = DateTimeColumn("col4", Nullable)
var table1Col1 = NewIntegerColumn("col1", Nullable)
var table1Col2 = NewIntegerColumn("col2", Nullable)
var table1Col3 = NewIntegerColumn("col3", Nullable)
var table1Col4 = NewTimeColumn("col4", Nullable)
var table1 = NewTable(
"db",
"table1",
@ -14,16 +12,16 @@ var table1 = NewTable(
table1Col3,
table1Col4)
var table2Col3 = IntColumn("col3", Nullable)
var table2Col4 = IntColumn("col4", Nullable)
var table2Col3 = NewIntegerColumn("col3", Nullable)
var table2Col4 = NewIntegerColumn("col4", Nullable)
var table2 = NewTable(
"db",
"table2",
table2Col3,
table2Col4)
var table3Col1 = IntColumn("col1", Nullable)
var table3Col2 = IntColumn("col2", Nullable)
var table3Col1 = NewIntegerColumn("col1", Nullable)
var table3Col2 = NewIntegerColumn("col2", Nullable)
var table3 = NewTable(
"db",
"table3",

View file

@ -1,52 +1,17 @@
package tests
import (
"database/sql"
"fmt"
"github.com/sub0Zero/go-sqlbuilder/generator"
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
"gotest.tools/assert"
"os"
"strings"
"testing"
"time"
)
const (
folderPath = ".test_files/"
host = "localhost"
port = 5432
user = "postgres"
password = "postgres"
dbname = "dvd_rental"
schemaName = "dvds"
)
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
var db *sql.DB
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files
func TestMain(m *testing.M) {
fmt.Println("Begin")
var err error
db, err = sql.Open("postgres", connectString)
if err != nil {
panic("Failed to connect to test db")
}
defer db.Close()
ret := m.Run()
db.Close()
fmt.Println("END")
os.Exit(ret)
}
func TestGenerateModel(t *testing.T) {
err := generator.Generate(folderPath, connectString, dbname, schemaName)

79
tests/insert_test.go Normal file
View file

@ -0,0 +1,79 @@
package tests
import (
"fmt"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert"
"testing"
)
func TestInsertValues(t *testing.T) {
insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name).
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial").
VALUES("http://www.google.com", "Google").
VALUES("http://www.yahoo.com", "Yahoo").
VALUES("http://www.bing.com", "Bing").
RETURNING(table.Link.ID)
insertQueryStr, err := insertQuery.String()
assert.NilError(t, err)
fmt.Println(insertQueryStr)
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial'), ('http://www.google.com','Google'), ('http://www.yahoo.com','Yahoo'), ('http://www.bing.com','Bing') RETURNING link.id;`)
res, err := insertQuery.Execute(db)
assert.NilError(t, err)
rowsAffected, err := res.RowsAffected()
assert.NilError(t, err)
assert.Equal(t, rowsAffected, int64(4))
link := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns).Execute(db, &link)
assert.NilError(t, err)
assert.Equal(t, len(link), 4)
assert.DeepEqual(t, link[0], model.Link{
ID: 1,
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial",
Rel: nil,
})
assert.DeepEqual(t, link[3], model.Link{
ID: 4,
URL: "http://www.bing.com",
Name: "Bing",
Rel: nil,
})
}
func TestInsertDataObject(t *testing.T) {
linkData := model.Link{
URL: "http://www.duckduckgo.com",
Name: "Duck Duck go",
Rel: nil,
}
query := table.Link.INSERT(table.Link.URL, table.Link.Name).
VALUES_MAPPING(linkData)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
result, err := query.Execute(db)
assert.NilError(t, err)
fmt.Println(result)
}

75
tests/main_test.go Normal file
View file

@ -0,0 +1,75 @@
package tests
import (
"database/sql"
"fmt"
"os"
"testing"
)
const (
folderPath = ".test_files/"
host = "localhost"
port = 5432
user = "postgres"
password = "postgres"
dbname = "dvd_rental"
schemaName = "dvds"
)
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
var db *sql.DB
var tx *sql.Tx
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files
func TestMain(m *testing.M) {
fmt.Println("Begin")
var err error
db, err = sql.Open("postgres", connectString)
if err != nil {
panic("Failed to connect to test db")
}
tx, _ = db.Begin()
defer cleanUp()
dbInit()
ret := m.Run()
cleanUp()
fmt.Println("END")
os.Exit(ret)
}
func cleanUp() {
fmt.Println("CLEAN UP")
tx.Rollback()
db.Close()
}
func dbInit() {
linkTableCreate := `
DROP TABLE IF EXISTS test_sample.link;
CREATE TABLE IF NOT EXISTS test_sample.link (
ID serial PRIMARY KEY,
url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL,
description VARCHAR (255),
rel VARCHAR (50)
);`
result, err := db.Exec(linkTableCreate)
if err != nil {
panic(err)
}
fmt.Println(result)
}

8
types/db.go Normal file
View file

@ -0,0 +1,8 @@
package types
import "database/sql"
type Db interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
}