From 599a8c537aefde0354139668dfb5cc84ab23d886 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Sun, 7 Apr 2019 09:58:12 +0200 Subject: [PATCH] Add support for INSERT statements. --- sqlbuilder/bool_expression_test.go | 4 +- sqlbuilder/column_types_test.go | 8 +- sqlbuilder/execution/execution.go | 3 +- sqlbuilder/insert_statement.go | 225 +++++++++++++++++++++++++++ sqlbuilder/insert_statement_test.go | 125 +++++++++++++++ sqlbuilder/select_statement.go | 7 +- sqlbuilder/select_statement_table.go | 9 ++ sqlbuilder/statement.go | 188 ---------------------- sqlbuilder/statement_test.go | 28 ++-- sqlbuilder/table.go | 51 +++--- sqlbuilder/test_utils.go | 18 +-- tests/generator_test.go | 35 ----- tests/insert_test.go | 79 ++++++++++ tests/main_test.go | 75 +++++++++ types/db.go | 8 + 15 files changed, 586 insertions(+), 277 deletions(-) create mode 100644 sqlbuilder/insert_statement.go create mode 100644 sqlbuilder/insert_statement_test.go create mode 100644 tests/insert_test.go create mode 100644 tests/main_test.go create mode 100644 types/db.go diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index eba700f..b0bc407 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -19,7 +19,7 @@ func TestBinaryExpression(t *testing.T) { alias := boolExpression.As("alias_eq_expression") out := bytes.Buffer{} - err := alias.SerializeSql(&out) + err := alias.SerializeForProjection(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `2 = 3 AS "alias_eq_expression"`) @@ -59,7 +59,7 @@ func TestUnaryExpression(t *testing.T) { alias := notExpression.As("alias_not_expression") out := bytes.Buffer{} - err := alias.SerializeSql(&out) + err := alias.SerializeForProjection(&out) assert.NilError(t, err) assert.Equal(t, out.String(), ` NOT 2 = 1 AS "alias_not_expression"`) diff --git a/sqlbuilder/column_types_test.go b/sqlbuilder/column_types_test.go index 355175d..344e1f6 100644 --- a/sqlbuilder/column_types_test.go +++ b/sqlbuilder/column_types_test.go @@ -31,7 +31,7 @@ func TestNewBoolColumn(t *testing.T) { err = boolColumn.setTableName("table1") assert.NilError(t, err) aliasedBoolColumn := boolColumn.As("alias1") - err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION) + err = aliasedBoolColumn.SerializeForProjection(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `table1.col AS "alias1"`) } @@ -61,7 +61,7 @@ func TestNewIntColumn(t *testing.T) { err = integerColumn.setTableName("table1") assert.NilError(t, err) aliasedBoolColumn := integerColumn.As("alias1") - err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION) + err = aliasedBoolColumn.SerializeForProjection(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `table1.col AS "alias1"`) } @@ -76,7 +76,7 @@ func TestNewNumericColumnColumn(t *testing.T) { assert.Equal(t, out.String(), "col") out.Reset() - err = numericColumn.SerializeSql(&out, FOR_PROJECTION) + err = numericColumn.SerializeSql(&out) assert.NilError(t, err) assert.Equal(t, out.String(), "col") @@ -91,7 +91,7 @@ func TestNewNumericColumnColumn(t *testing.T) { err = numericColumn.setTableName("table1") assert.NilError(t, err) aliasedBoolColumn := numericColumn.As("alias1") - err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION) + err = aliasedBoolColumn.SerializeForProjection(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `table1.col AS "alias1"`) } diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 94ca3db..d249240 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/serenize/snaker" + "github.com/sub0Zero/go-sqlbuilder/types" "reflect" "regexp" "strconv" @@ -13,7 +14,7 @@ import ( "time" ) -func Execute(db *sql.DB, query string, destinationPtr interface{}) error { +func Execute(db types.Db, query string, destinationPtr interface{}) error { if db == nil { return errors.New("db is nil") } diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go new file mode 100644 index 0000000..8c8b660 --- /dev/null +++ b/sqlbuilder/insert_statement.go @@ -0,0 +1,225 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql" + "github.com/dropbox/godropbox/errors" + "github.com/serenize/snaker" + "github.com/sub0Zero/go-sqlbuilder/types" + "reflect" + "strings" +) + +type InsertStatement interface { + Statement + + // Add a row of values to the insert statement. + VALUES(values ...interface{}) InsertStatement + // Map or stracture mapped to column names + VALUES_MAPPING(data interface{}) InsertStatement + + RETURNING(column ...Expression) InsertStatement + + Execute(db types.Db) (sql.Result, error) +} + +func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { + return &insertStatementImpl{ + table: t, + columns: columns, + rows: make([][]Expression, 0, 1), + returning: make([]Expression, 0, 1), + } +} + +type columnAssignment struct { + col Column + expr Expression +} + +type insertStatementImpl struct { + table WritableTable + columns []Column + rows [][]Expression + returning []Expression + + errors []string +} + +func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + query, err := i.String() + + if err != nil { + return + } + + res, err = db.Exec(query) + + return +} + +//func (i *insertStatementImpl) ExecuteInTx(tx *sql.Tx) (res sql.Result, err error) { +// query, err := i.String() +// +// if err != nil { +// return +// } +// +// res, err = tx.Exec(query) +// +// return +//} + +func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { + literalRow := []Expression{} + + for _, value := range values { + literalRow = append(literalRow, Literal(value)) + } + + s.rows = append(s.rows, literalRow) + return s +} + +func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement { + if data == nil { + i.addError("Add method data is nil.") + return i + } + + value := reflect.ValueOf(data) + + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + if value.Kind() != reflect.Struct { + i.addError("Add method data is not struct or pointer to struct.") + return i + } + + rowValues := []Expression{} + + for _, column := range i.columns { + columnName := column.Name() + structFieldName := snaker.SnakeToCamel(columnName) + + structField := value.FieldByName(structFieldName) + + if !structField.IsValid() { + i.addError("Add() : Data structure doesn't contain field : " + structFieldName + " for column " + columnName) + return i + } + + rowValues = append(rowValues, Literal(structField.Interface())) + } + + i.rows = append(i.rows, rowValues) + + return i +} + +func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement { + i.returning = column + + return i +} + +func (i *insertStatementImpl) addError(err string) { + i.errors = append(i.errors, err) +} + +func (s *insertStatementImpl) String() (sql string, err error) { + buf := new(bytes.Buffer) + _, _ = buf.WriteString("INSERT ") + _, _ = buf.WriteString("INTO ") + + if len(s.errors) > 0 { + return "", errors.New("sql builder errors: " + strings.Join(s.errors, ", ")) + } + + if s.table == nil { + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + } + + buf.WriteString(s.table.SchemaName() + "." + s.table.TableName()) + + if len(s.columns) == 0 { + return "", errors.Newf( + "No column specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" (") + for i, col := range s.columns { + if i > 0 { + _ = buf.WriteByte(',') + } + + if col == nil { + return "", errors.Newf( + "nil column in columns list. Generated sql: %s", + buf.String()) + } + + buf.WriteString(col.Name()) + } + + if len(s.rows) == 0 { + return "", errors.Newf( + "No row specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(") VALUES (") + for row_i, row := range s.rows { + if row_i > 0 { + _, _ = buf.WriteString(", (") + } + + if len(row) != len(s.columns) { + return "", errors.Newf( + "# of values does not match # of columns. Generated sql: %s", + buf.String()) + } + + for col_i, value := range row { + if col_i > 0 { + _ = buf.WriteByte(',') + } + + if value == nil { + return "", errors.Newf( + "nil value in row %d col %d. Generated sql: %s", + row_i, + col_i, + buf.String()) + } + + if err = value.SerializeSql(buf); err != nil { + return + } + } + _ = buf.WriteByte(')') + } + + if len(s.returning) > 0 { + buf.WriteString(" RETURNING ") + + for i, column := range s.returning { + if i > 0 { + buf.WriteString(",") + } + + err = column.SerializeSql(buf) + + if err != nil { + return + } + } + } + + buf.WriteByte(';') + + return buf.String(), nil +} diff --git a/sqlbuilder/insert_statement_test.go b/sqlbuilder/insert_statement_test.go new file mode 100644 index 0000000..d43b872 --- /dev/null +++ b/sqlbuilder/insert_statement_test.go @@ -0,0 +1,125 @@ +package sqlbuilder + +import ( + "fmt" + "gotest.tools/assert" + "testing" + "time" +) + +func TestInsertNoColumn(t *testing.T) { + _, err := table1.INSERT().VALUES().String() + + assert.Assert(t, err != nil) +} + +func TestInsertNoRow(t *testing.T) { + _, err := table1.INSERT(table1Col1).String() + + assert.Assert(t, err != nil) +} + +func TestInsertColumnLengthMismatch(t *testing.T) { + _, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).String() + + fmt.Println(err) + assert.Assert(t, err != nil) +} + +func TestInsertNilValue(t *testing.T) { + _, err := table1.INSERT(table1Col1).VALUES(nil).String() + + assert.Assert(t, err != nil) +} + +func TestInsertNilColumn(t *testing.T) { + _, err := table1.INSERT(nil).VALUES(1).String() + + assert.Assert(t, err != nil) +} + +func TestInsertSingleValue(t *testing.T) { + sql, err := table1.INSERT(table1Col1).VALUES(1).String() + assert.NilError(t, err) + + assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES (1)") +} + +func TestInsertDate(t *testing.T) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + sql, err := table1.INSERT(table1Col4).VALUES(date).String() + assert.NilError(t, err) + + assert.Equal(t, sql, "INSERT INTO db.table1 (col4) "+ + "VALUES ('1999-01-02 03:04:05.000000')") +} + +func TestInsertMultipleValues(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1Col2, table1Col3) + stmt.VALUES(1, 2, 3) + + sql, err := stmt.String() + assert.NilError(t, err) + + assert.Equal(t, sql, "INSERT INTO db.table1 "+ + "(col1,col2,col3) "+ + "VALUES (1,2,3)") +} + +func TestInsertMultipleRows(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1Col2). + VALUES(1, 2). + VALUES(11, 22). + VALUES(111, 222) + + sql, err := stmt.String() + assert.NilError(t, err) + + assert.Equal(t, sql, "INSERT INTO db.table1 "+ + "(col1,col2) "+ + "VALUES (1,2), (11,22), (111,222)") +} + +func TestInsertValuesFromModel(t *testing.T) { + type Table1Model struct { + Col1 int + Col2 string + } + + toInsert := Table1Model{ + Col1: 1, + Col2: "one", + } + + stmt := table1.INSERT(table1Col1, table1Col2). + VALUES_MAPPING(toInsert) + + sql, err := stmt.String() + + assert.NilError(t, err) + + fmt.Println(sql) + + assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES (1,'one')`) +} + +func TestInsertValuesFromModelColumnMismatch(t *testing.T) { + type Table1Model struct { + Col1Prim int + Col2 string + } + + toInsert := Table1Model{ + Col1Prim: 1, + Col2: "one", + } + + stmt := table1.INSERT(table1Col1, table1Col2). + VALUES_MAPPING(toInsert) + + _, err := stmt.String() + + fmt.Println(err) + assert.Assert(t, err != nil) +} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 2c53fd2..ea0096a 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -2,10 +2,10 @@ package sqlbuilder import ( "bytes" - "database/sql" "fmt" "github.com/dropbox/godropbox/errors" "github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution" + "github.com/sub0Zero/go-sqlbuilder/types" "reflect" ) @@ -28,6 +28,9 @@ type SelectStatement interface { Copy() SelectStatement AsTable(alias string) *SelectStatementTable + + Execute(db types.Db, destination interface{}) error + //ExecuteInTx(tx *sql.Tx, destination interface{}) error } // NOTE: SelectStatement purposely does not implement the Table interface since @@ -84,7 +87,7 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { } } -func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error { +func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error { destinationType := reflect.TypeOf(destination) if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct { diff --git a/sqlbuilder/select_statement_table.go b/sqlbuilder/select_statement_table.go index fb75e09..0410207 100644 --- a/sqlbuilder/select_statement_table.go +++ b/sqlbuilder/select_statement_table.go @@ -8,6 +8,15 @@ type SelectStatementTable struct { alias string } +// Returns the tableName's name in the database +func (t *SelectStatementTable) SchemaName() string { + return "" +} + +func (s *SelectStatementTable) TableName() string { + return s.alias +} + func (s *SelectStatementTable) Columns() []Column { return s.columns } diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 653c4a5..f8b6b9b 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -12,17 +12,6 @@ import ( type Statement interface { // String returns generated SQL as string. String() (sql string, err error) - Execute(db *sql.DB, destination interface{}) error -} - -type InsertStatement interface { - Statement - - // Add a row of values to the insert statement. - Add(row ...Expression) InsertStatement - AddOnDuplicateKeyUpdate(col Column, expr Expression) InsertStatement - Comment(comment string) InsertStatement - IgnoreDuplicates(ignore bool) InsertStatement } // By default, rows selected by a UNION statement are out-of-order @@ -261,183 +250,6 @@ func (us *unionStatementImpl) String() (sql string, err error) { return buf.String(), nil } -// -// INSERT Statement ============================================================ -// - -func newInsertStatement( - t WritableTable, - columns ...Column) InsertStatement { - - return &insertStatementImpl{ - table: t, - columns: columns, - rows: make([][]Expression, 0, 1), - onDuplicateKeyUpdates: make([]columnAssignment, 0, 0), - } -} - -type columnAssignment struct { - col Column - expr Expression -} - -type insertStatementImpl struct { - table WritableTable - columns []Column - rows [][]Expression - onDuplicateKeyUpdates []columnAssignment - comment string - ignore bool -} - -func (i *insertStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -func (s *insertStatementImpl) Add( - row ...Expression) InsertStatement { - - s.rows = append(s.rows, row) - return s -} - -func (s *insertStatementImpl) AddOnDuplicateKeyUpdate( - col Column, - expr Expression) InsertStatement { - - s.onDuplicateKeyUpdates = append( - s.onDuplicateKeyUpdates, - columnAssignment{col, expr}) - - return s -} - -func (s *insertStatementImpl) IgnoreDuplicates(ignore bool) InsertStatement { - s.ignore = ignore - return s -} - -func (s *insertStatementImpl) Comment(comment string) InsertStatement { - s.comment = comment - return s -} - -func (s *insertStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("INSERT ") - if s.ignore { - _, _ = buf.WriteString("IGNORE ") - } - _, _ = buf.WriteString("INTO ") - - if err = writeComment(s.comment, buf); err != nil { - return - } - - if s.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) - } - - if err = s.table.SerializeSql(buf); err != nil { - return - } - - if len(s.columns) == 0 { - return "", errors.Newf( - "No column specified. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(" (") - for i, col := range s.columns { - if i > 0 { - _ = buf.WriteByte(',') - } - - if col == nil { - return "", errors.Newf( - "nil column in columns list. Generated sql: %s", - buf.String()) - } - - if err = col.SerializeSql(buf, FOR_PROJECTION); err != nil { - return - } - } - - if len(s.rows) == 0 { - return "", errors.Newf( - "No row specified. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(") VALUES (") - for row_i, row := range s.rows { - if row_i > 0 { - _, _ = buf.WriteString(", (") - } - - if len(row) != len(s.columns) { - return "", errors.Newf( - "# of values does not match # of columns. Generated sql: %s", - buf.String()) - } - - for col_i, value := range row { - if col_i > 0 { - _ = buf.WriteByte(',') - } - - if value == nil { - return "", errors.Newf( - "nil value in row %d col %d. Generated sql: %s", - row_i, - col_i, - buf.String()) - } - - if err = value.SerializeSql(buf); err != nil { - return - } - } - _ = buf.WriteByte(')') - } - - if len(s.onDuplicateKeyUpdates) > 0 { - _, _ = buf.WriteString(" ON DUPLICATE KEY UPDATE ") - for i, colExpr := range s.onDuplicateKeyUpdates { - if i > 0 { - _, _ = buf.WriteString(", ") - } - - if colExpr.col == nil { - return "", errors.Newf( - "nil column in on duplicate key update list. "+"Generated sql: %s", - buf.String()) - } - - if err = colExpr.col.SerializeSql(buf, FOR_PROJECTION); err != nil { - return - } - - _ = buf.WriteByte('=') - - if colExpr.expr == nil { - return "", errors.Newf( - "nil expression in on duplicate key update list. "+"Generated sql: %s", - buf.String()) - } - - if err = colExpr.expr.SerializeSql(buf); err != nil { - return - } - } - } - - return buf.String(), nil -} - // // UPDATE statement =========================================================== // diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index 48409f1..f580fca 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -237,37 +237,37 @@ func (s *StmtSuite) TestSelectDistinct(c *gc.C) { // func (s *StmtSuite) TestInsertNoColumn(c *gc.C) { - _, err := table1.Insert().Add().String() + _, err := table1.INSERT().Add().String() c.Assert(err, gc.NotNil) } func (s *StmtSuite) TestInsertNoRow(c *gc.C) { - _, err := table1.Insert(table1Col1).String() + _, err := table1.INSERT(table1Col1).String() c.Assert(err, gc.NotNil) } func (s *StmtSuite) TestInsertColumnLengthMismatch(c *gc.C) { - _, err := table1.Insert(table1Col1, table1Col2).Add(nil).String() + _, err := table1.INSERT(table1Col1, table1Col2).Add(nil).String() c.Assert(err, gc.NotNil) } func (s *StmtSuite) TestInsertNilValue(c *gc.C) { - _, err := table1.Insert(table1Col1).Add(nil).String() + _, err := table1.INSERT(table1Col1).Add(nil).String() c.Assert(err, gc.NotNil) } func (s *StmtSuite) TestInsertNilColumn(c *gc.C) { - _, err := table1.Insert(nil).Add(Literal(1)).String() + _, err := table1.INSERT(nil).Add(Literal(1)).String() c.Assert(err, gc.NotNil) } func (s *StmtSuite) TestInsertSingleValue(c *gc.C) { - sql, err := table1.Insert(table1Col1).Add(Literal(1)).String() + sql, err := table1.INSERT(table1Col1).Add(Literal(1)).String() c.Assert(err, gc.IsNil) c.Assert( @@ -279,7 +279,7 @@ func (s *StmtSuite) TestInsertSingleValue(c *gc.C) { func (s *StmtSuite) TestInsertDate(c *gc.C) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) - sql, err := table1.Insert(table1Col4).Add(Literal(date)).String() + sql, err := table1.INSERT(table1Col4).Add(Literal(date)).String() c.Assert(err, gc.IsNil) c.Assert( @@ -290,7 +290,7 @@ func (s *StmtSuite) TestInsertDate(c *gc.C) { } func (s *StmtSuite) TestInsertIgnore(c *gc.C) { - stmt := table1.Insert(table1Col1).Add(Literal(1)).IgnoreDuplicates(true) + stmt := table1.INSERT(table1Col1).Add(Literal(1)).IgnoreDuplicates(true) sql, err := stmt.String() c.Assert(err, gc.IsNil) @@ -301,7 +301,7 @@ func (s *StmtSuite) TestInsertIgnore(c *gc.C) { } func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) { - stmt := table1.Insert(table1Col1, table1Col2, table1Col3) + stmt := table1.INSERT(table1Col1, table1Col2, table1Col3) stmt.Add(Literal(1), Literal(2), Literal(3)) sql, err := stmt.String() @@ -316,7 +316,7 @@ func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) { } func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) { - stmt := table1.Insert(table1Col1, table1Col2) + stmt := table1.INSERT(table1Col1, table1Col2) stmt.Add(Literal(1), Literal(2)) stmt.Add(Literal(11), Literal(22)) stmt.Add(Literal(111), Literal(222)) @@ -333,7 +333,7 @@ func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) { } func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) { - stmt := table1.Insert(table1Col1, table1Col2) + stmt := table1.INSERT(table1Col1, table1Col2) stmt.Add(Literal(1), Literal(2)) stmt.AddOnDuplicateKeyUpdate(nil, Literal(3)) @@ -342,7 +342,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) { } func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) { - stmt := table1.Insert(table1Col1, table1Col2) + stmt := table1.INSERT(table1Col1, table1Col2) stmt.Add(Literal(1), Literal(2)) stmt.AddOnDuplicateKeyUpdate(table1Col1, nil) @@ -351,7 +351,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) { } func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) { - stmt := table1.Insert(table1Col1, table1Col2) + stmt := table1.INSERT(table1Col1, table1Col2) stmt.Add(Literal(1), Literal(2)) stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3)) @@ -368,7 +368,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) { } func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) { - stmt := table1.Insert(table1Col1, table1Col2) + stmt := table1.INSERT(table1Col1, table1Col2) stmt.Add(Literal(1), Literal(2)) stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3)) stmt.AddOnDuplicateKeyUpdate(table1Col2, Literal(4)) diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 98c4867..d7d08aa 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -8,17 +8,20 @@ import ( "github.com/dropbox/godropbox/errors" ) -// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause -// are not supported. -type ReadableTable interface { +type TableInterface interface { + SchemaName() string + TableName() string // Returns the list of columns that are in the current tableName expression. Columns() []Column - - //Column(name string) Column - // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. SerializeSql(out *bytes.Buffer) error +} + +// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause +// are not supported. +type ReadableTable interface { + TableInterface // Generates a select query on the current tableName. SELECT(projections ...Projection) SelectStatement @@ -26,8 +29,6 @@ type ReadableTable interface { // Creates a inner join tableName expression using onCondition. INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable - //InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable - // Creates a left join tableName expression using onCondition. LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable @@ -41,15 +42,9 @@ type ReadableTable interface { // The sql tableName write interface. type WritableTable interface { - // Returns the list of columns that are in the tableName. - Columns() []Column + TableInterface - // Generates the sql string for the current tableName expression. Note: the - // generated string may not be a valid/executable sql statement. - // The database is the name of the database the tableName is on - SerializeSql(out *bytes.Buffer) error - - Insert(columns ...Column) InsertStatement + INSERT(columns ...Column) InsertStatement Update() UpdateStatement Delete() DeleteStatement } @@ -72,7 +67,7 @@ func NewTable(schemaName, name string, columns ...Column) *Table { if err != nil { panic(err) } - t.columnLookup[c.Name()] = c + t.columnLookup[c.TableName()] = c } if len(columns) == 0 { @@ -132,11 +127,16 @@ func (t *Table) SetAlias(alias string) { } // Returns the tableName's name in the database -func (t *Table) Name() string { +func (t *Table) SchemaName() string { + return t.schemaName +} + +// Returns the tableName's name in the database +func (t *Table) TableName() string { return t.name } -func (t *Table) SchemaName() string { +func (t *Table) SchemaTableName() string { return t.schemaName } @@ -161,7 +161,7 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error { _, _ = out.WriteString(t.schemaName) _, _ = out.WriteString(".") - _, _ = out.WriteString(t.Name()) + _, _ = out.WriteString(t.TableName()) if len(t.alias) > 0 { out.WriteString(" AS ") @@ -225,7 +225,7 @@ func (t *Table) CrossJoin(table ReadableTable) ReadableTable { return CrossJoin(t, table) } -func (t *Table) Insert(columns ...Column) InsertStatement { +func (t *Table) INSERT(columns ...Column) InsertStatement { return newInsertStatement(t, columns...) } @@ -308,6 +308,15 @@ func CrossJoin( return newJoinTable(lhs, rhs, CROSS_JOIN, nil) } +// Returns the tableName's name in the database +func (t *joinTable) SchemaName() string { + return "" +} + +func (t *joinTable) TableName() string { + return "" +} + func (t *joinTable) Columns() []Column { columns := make([]Column, 0) columns = append(columns, t.lhs.Columns()...) diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go index 1bcd0ec..2274f95 100644 --- a/sqlbuilder/test_utils.go +++ b/sqlbuilder/test_utils.go @@ -1,11 +1,9 @@ -// +build disabled - package sqlbuilder -var table1Col1 = IntColumn("col1", Nullable) -var table1Col2 = IntColumn("col2", Nullable) -var table1Col3 = IntColumn("col3", Nullable) -var table1Col4 = DateTimeColumn("col4", Nullable) +var table1Col1 = NewIntegerColumn("col1", Nullable) +var table1Col2 = NewIntegerColumn("col2", Nullable) +var table1Col3 = NewIntegerColumn("col3", Nullable) +var table1Col4 = NewTimeColumn("col4", Nullable) var table1 = NewTable( "db", "table1", @@ -14,16 +12,16 @@ var table1 = NewTable( table1Col3, table1Col4) -var table2Col3 = IntColumn("col3", Nullable) -var table2Col4 = IntColumn("col4", Nullable) +var table2Col3 = NewIntegerColumn("col3", Nullable) +var table2Col4 = NewIntegerColumn("col4", Nullable) var table2 = NewTable( "db", "table2", table2Col3, table2Col4) -var table3Col1 = IntColumn("col1", Nullable) -var table3Col2 = IntColumn("col2", Nullable) +var table3Col1 = NewIntegerColumn("col1", Nullable) +var table3Col2 = NewIntegerColumn("col2", Nullable) var table3 = NewTable( "db", "table3", diff --git a/tests/generator_test.go b/tests/generator_test.go index 4b77b84..504338b 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -1,52 +1,17 @@ package tests import ( - "database/sql" "fmt" "github.com/sub0Zero/go-sqlbuilder/generator" "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" - "os" "strings" "testing" "time" ) -const ( - folderPath = ".test_files/" - host = "localhost" - port = 5432 - user = "postgres" - password = "postgres" - dbname = "dvd_rental" - schemaName = "dvds" -) - -var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname) -var db *sql.DB - -//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files -//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files - -func TestMain(m *testing.M) { - fmt.Println("Begin") - var err error - db, err = sql.Open("postgres", connectString) - if err != nil { - panic("Failed to connect to test db") - } - defer db.Close() - - ret := m.Run() - - db.Close() - fmt.Println("END") - - os.Exit(ret) -} - func TestGenerateModel(t *testing.T) { err := generator.Generate(folderPath, connectString, dbname, schemaName) diff --git a/tests/insert_test.go b/tests/insert_test.go new file mode 100644 index 0000000..8e25181 --- /dev/null +++ b/tests/insert_test.go @@ -0,0 +1,79 @@ +package tests + +import ( + "fmt" + "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model" + "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table" + "gotest.tools/assert" + "testing" +) + +func TestInsertValues(t *testing.T) { + insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial"). + VALUES("http://www.google.com", "Google"). + VALUES("http://www.yahoo.com", "Yahoo"). + VALUES("http://www.bing.com", "Bing"). + RETURNING(table.Link.ID) + + insertQueryStr, err := insertQuery.String() + + assert.NilError(t, err) + + fmt.Println(insertQueryStr) + + assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial'), ('http://www.google.com','Google'), ('http://www.yahoo.com','Yahoo'), ('http://www.bing.com','Bing') RETURNING link.id;`) + res, err := insertQuery.Execute(db) + + assert.NilError(t, err) + + rowsAffected, err := res.RowsAffected() + assert.NilError(t, err) + + assert.Equal(t, rowsAffected, int64(4)) + + link := []model.Link{} + + err = table.Link.SELECT(table.Link.AllColumns).Execute(db, &link) + + assert.NilError(t, err) + + assert.Equal(t, len(link), 4) + + assert.DeepEqual(t, link[0], model.Link{ + ID: 1, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", + Rel: nil, + }) + + assert.DeepEqual(t, link[3], model.Link{ + ID: 4, + URL: "http://www.bing.com", + Name: "Bing", + Rel: nil, + }) +} + +func TestInsertDataObject(t *testing.T) { + linkData := model.Link{ + URL: "http://www.duckduckgo.com", + Name: "Duck Duck go", + Rel: nil, + } + + query := table.Link.INSERT(table.Link.URL, table.Link.Name). + VALUES_MAPPING(linkData) + + queryStr, err := query.String() + + assert.NilError(t, err) + + fmt.Println(queryStr) + + result, err := query.Execute(db) + + assert.NilError(t, err) + + fmt.Println(result) +} diff --git a/tests/main_test.go b/tests/main_test.go new file mode 100644 index 0000000..12c5057 --- /dev/null +++ b/tests/main_test.go @@ -0,0 +1,75 @@ +package tests + +import ( + "database/sql" + "fmt" + "os" + "testing" +) + +const ( + folderPath = ".test_files/" + host = "localhost" + port = 5432 + user = "postgres" + password = "postgres" + dbname = "dvd_rental" + schemaName = "dvds" +) + +var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname) +var db *sql.DB +var tx *sql.Tx + +//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files +//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files + +func TestMain(m *testing.M) { + fmt.Println("Begin") + + var err error + db, err = sql.Open("postgres", connectString) + if err != nil { + panic("Failed to connect to test db") + } + tx, _ = db.Begin() + defer cleanUp() + + dbInit() + + ret := m.Run() + + cleanUp() + fmt.Println("END") + + os.Exit(ret) +} + +func cleanUp() { + fmt.Println("CLEAN UP") + + tx.Rollback() + db.Close() +} + +func dbInit() { + linkTableCreate := ` +DROP TABLE IF EXISTS test_sample.link; + +CREATE TABLE IF NOT EXISTS test_sample.link ( + ID serial PRIMARY KEY, + url VARCHAR (255) NOT NULL, + name VARCHAR (255) NOT NULL, + description VARCHAR (255), + rel VARCHAR (50) +);` + + result, err := db.Exec(linkTableCreate) + + if err != nil { + panic(err) + } + + fmt.Println(result) + +} diff --git a/types/db.go b/types/db.go new file mode 100644 index 0000000..a21dacd --- /dev/null +++ b/types/db.go @@ -0,0 +1,8 @@ +package types + +import "database/sql" + +type Db interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) +}