diff --git a/expression_table.go b/expression_table.go index c3a3114..7b1a750 100644 --- a/expression_table.go +++ b/expression_table.go @@ -36,7 +36,7 @@ func (e *expressionTableImpl) Alias() string { return e.alias } -func (e *expressionTableImpl) columns() []Column { +func (e *expressionTableImpl) columns() []column { return nil } diff --git a/generator/internal/metadata/postgres-metadata/table_info.go b/generator/internal/metadata/postgres-metadata/table_info.go index c738c59..43e8265 100644 --- a/generator/internal/metadata/postgres-metadata/table_info.go +++ b/generator/internal/metadata/postgres-metadata/table_info.go @@ -16,10 +16,24 @@ func (t TableInfo) Name() string { return t.name } -func (t TableInfo) IsUnique(columnName string) bool { +func (t TableInfo) IsPrimaryKey(columnName string) bool { return t.PrimaryKeys[columnName] } +func (t TableInfo) MutableColumns() []ColumnInfo { + ret := []ColumnInfo{} + + for _, column := range t.Columns { + if t.IsPrimaryKey(column.Name) { + continue + } + + ret = append(ret, column) + } + + return ret +} + func (t TableInfo) GetImports() []string { imports := map[string]string{} diff --git a/generator/postgresgen/templates.go b/generator/postgresgen/templates.go index daa0e99..267168a 100644 --- a/generator/postgresgen/templates.go +++ b/generator/postgresgen/templates.go @@ -36,7 +36,8 @@ type {{.GoStructName}} struct { {{camelize .Name}} jet.Column{{.SqlBuilderColumnType}} {{- end}} - AllColumns jet.ColumnList + AllColumns jet.ColumnList + MutableColumns jet.ColumnList } // creates new {{.GoStructName}} with assigned alias @@ -63,7 +64,8 @@ func new{{.GoStructName}}() *{{.GoStructName}} { {{camelize .Name}}: {{camelize .Name}}Column, {{- end}} - AllColumns: jet.ColumnList{ {{template "column-list" .Columns}} }, + AllColumns: jet.ColumnList{ {{template "column-list" .Columns}} }, + MutableColumns: jet.ColumnList{ {{template "column-list" .MutableColumns}} }, } } @@ -82,7 +84,7 @@ import ( type {{camelize .Name}} struct { {{- range .Columns}} - {{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsUnique .Name)}}" + ` + {{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` {{- end}} } ` diff --git a/insert_statement.go b/insert_statement.go index ba3e39d..c8e73e5 100644 --- a/insert_statement.go +++ b/insert_statement.go @@ -10,10 +10,13 @@ import ( type InsertStatement interface { Statement - // Add a row of values to the insert Statement. + // Insert row of values VALUES(value interface{}, values ...interface{}) InsertStatement - // Model structure mapped to column names - USING(data interface{}) InsertStatement + // Insert row of values, where value for each column is extracted from filed of structure data. + // If data is not struct or there is no field for every column selected, this method will panic. + MODEL(data interface{}) InsertStatement + + MODELS(data interface{}) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement @@ -40,8 +43,13 @@ func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) I return i } -func (i *insertStatementImpl) USING(data interface{}) InsertStatement { - i.rows = append(i.rows, unwindRowFromModel(i.columns, data)) +func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { + i.rows = append(i.rows, unwindRowFromModel(i.getColumns(), data)) + return i +} + +func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { + i.rows = append(i.rows, unwindRowsFromModels(i.getColumns(), data)...) return i } @@ -55,6 +63,14 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState return i } +func (i *insertStatementImpl) getColumns() []column { + if len(i.columns) > 0 { + return i.columns + } + + return i.table.columns() +} + func (i *insertStatementImpl) DebugSql() (query string, err error) { return debugSql(i) } @@ -107,10 +123,6 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) queryData.newLine() queryData.writeString("(") - if len(row) != len(i.columns) { - return "", nil, errors.New("number of values does not match number of columns") - } - err = serializeClauseList(insert_statement, row, queryData) if err != nil { diff --git a/insert_statement_test.go b/insert_statement_test.go index 2e09d53..9e956d0 100644 --- a/insert_statement_test.go +++ b/insert_statement_test.go @@ -8,7 +8,6 @@ import ( func TestInvalidInsert(t *testing.T) { assertStatementErr(t, table1.INSERT(table1Col1), "no row values or query specified") - assertStatementErr(t, table1.INSERT(table1Col1, table1ColFloat).VALUES(11), "number of values does not match number of columns") assertStatementErr(t, table1.INSERT(nil).VALUES(1), "nil column in columns list") } @@ -79,8 +78,8 @@ func TestInsertValuesFromModel(t *testing.T) { } stmt := table1.INSERT(table1Col1, table1ColFloat). - USING(toInsert). - USING(&toInsert) + MODEL(toInsert). + MODEL(&toInsert) expectedSql := ` INSERT INTO db.table1 (col1, col_float) VALUES @@ -108,7 +107,7 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) { table1. INSERT(table1Col1, table1ColFloat). - USING(newData) + MODEL(newData) } func TestInsertFromNonStructModel(t *testing.T) { @@ -118,7 +117,7 @@ func TestInsertFromNonStructModel(t *testing.T) { assert.Equal(t, r, "argument mismatch: expected struct, got []int") }() - table2.INSERT(table2ColInt).USING([]int{}) + table2.INSERT(table2ColInt).MODEL([]int{}) } func TestInsertQuery(t *testing.T) { diff --git a/table.go b/table.go index 8103207..a7dfb2e 100644 --- a/table.go +++ b/table.go @@ -6,6 +6,10 @@ import ( "errors" ) +type table interface { + columns() []column +} + type readableTable interface { // Generates a select query on the current tableName. SELECT(projection projection, projections ...projection) SelectStatement @@ -24,13 +28,11 @@ type readableTable interface { // Creates a cross join tableName Expression using onCondition. CROSS_JOIN(table ReadableTable) ReadableTable - - columns() []Column } // The sql tableName write interface. type writableTable interface { - INSERT(column column, columns ...column) InsertStatement + INSERT(columns ...column) InsertStatement UPDATE(column column, columns ...column) UpdateStatement DELETE() DeleteStatement @@ -38,16 +40,19 @@ type writableTable interface { } type ReadableTable interface { + table readableTable clause } type WritableTable interface { + table writableTable clause } type Table interface { + table readableTable writableTable clause @@ -92,8 +97,14 @@ type writableTableInterfaceImpl struct { parent WritableTable } -func (w *writableTableInterfaceImpl) INSERT(column column, columns ...column) InsertStatement { - return newInsertStatement(w.parent, unwindColumns(column, columns...)) +func (w *writableTableInterfaceImpl) INSERT(columns ...column) InsertStatement { + //columnList := unwidColumnList(columns) + // + //if len(columns) == 0 { + // columnList = w.parent.columns() + //} + + return newInsertStatement(w.parent, unwidColumnList(columns)) } func (w *writableTableInterfaceImpl) UPDATE(column column, columns ...column) UpdateStatement { @@ -153,8 +164,14 @@ func (t *tableImpl) TableName() string { return t.name } -func (t *tableImpl) columns() []Column { - return t.columnList +func (t *tableImpl) columns() []column { + ret := []column{} + + for _, col := range t.columnList { + ret = append(ret, col) + } + + return ret } func (t *tableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { @@ -220,7 +237,7 @@ func (t *joinTable) TableName() string { return "" } -func (t *joinTable) columns() []Column { +func (t *joinTable) columns() []column { return append(t.lhs.columns(), t.rhs.columns()...) } @@ -289,3 +306,19 @@ func unwindColumns(column1 column, columns ...column) []column { return columnList } + +func unwidColumnList(columns []column) []column { + ret := []column{} + + for _, col := range columns { + if columnList, ok := col.(ColumnList); ok { + for _, c := range columnList { + ret = append(ret, c) + } + } else { + ret = append(ret, col) + } + } + + return ret +} diff --git a/tests/all_types_test.go b/tests/all_types_test.go index 7d3a361..42f2141 100644 --- a/tests/all_types_test.go +++ b/tests/all_types_test.go @@ -26,8 +26,8 @@ func TestAllTypesSelect(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) { query := AllTypes.INSERT(AllTypes.AllColumns). - USING(allTypesRow0). - USING(&allTypesRow1). + MODEL(allTypesRow0). + MODEL(&allTypesRow1). RETURNING(AllTypes.AllColumns) dest := []model.AllTypes{} diff --git a/tests/init/data/test_sample.sql b/tests/init/data/test_sample.sql index c8b11b0..5ac455b 100644 --- a/tests/init/data/test_sample.sql +++ b/tests/init/data/test_sample.sql @@ -144,11 +144,10 @@ VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 111 DROP TABLE IF EXISTS test_sample.link; CREATE TABLE IF NOT EXISTS test_sample.link ( - ID serial PRIMARY KEY, + id serial PRIMARY KEY, url VARCHAR (255) NOT NULL, name VARCHAR (255) NOT NULL, - description VARCHAR (255), - rel VARCHAR (50) + description VARCHAR (255) ); INSERT INTO test_sample.link (ID, url, name, description) VALUES diff --git a/tests/insert_test.go b/tests/insert_test.go index 84d7d56..229277c 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -9,22 +9,20 @@ import ( ) func TestInsertValues(t *testing.T) { - cleanUpLinkTable(t) var expectedSql = ` -INSERT INTO test_sample.link (id, url, name, rel) VALUES +INSERT INTO test_sample.link (id, url, name, description) VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), (101, 'http://www.google.com', 'Google', DEFAULT), (102, 'http://www.yahoo.com', 'Yahoo', NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", - link.description AS "link.description", - link.rel AS "link.rel"; + link.description AS "link.description"; ` - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Rel). + insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(101, "http://www.google.com", "Google", DEFAULT). VALUES(102, "http://www.yahoo.com", "Yahoo", nil). @@ -47,21 +45,18 @@ RETURNING link.id AS "link.id", ID: 100, URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", - Rel: nil, }) assert.DeepEqual(t, insertedLinks[1], model.Link{ ID: 101, URL: "http://www.google.com", Name: "Google", - Rel: nil, }) assert.DeepEqual(t, insertedLinks[2], model.Link{ ID: 102, URL: "http://www.yahoo.com", Name: "Yahoo", - Rel: nil, }) allLinks := []model.Link{} @@ -76,7 +71,24 @@ RETURNING link.id AS "link.id", assert.DeepEqual(t, insertedLinks, allLinks) } -func TestInsertDataObject(t *testing.T) { +func TestInsertEmptyColumnList(t *testing.T) { + cleanUpLinkTable(t) + + expectedSql := ` +INSERT INTO test_sample.link VALUES + (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +` + + stmt := Link.INSERT(). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) + + assertStatementSql(t, stmt, expectedSql, + 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") + + assertExec(t, stmt, 1) +} + +func TestInsertModelObject(t *testing.T) { var expectedSql = ` INSERT INTO test_sample.link (url, name) VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); @@ -85,12 +97,11 @@ INSERT INTO test_sample.link (url, name) VALUES linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", - Rel: nil, } query := Link. INSERT(Link.URL, Link.Name). - USING(linkData) + MODEL(linkData) assertStatementSql(t, query, expectedSql, "http://www.duckduckgo.com", "Duck Duck go") @@ -103,6 +114,75 @@ INSERT INTO test_sample.link (url, name) VALUES assert.Equal(t, rowsAffected, int64(1)) } +func TestInsertModelsObject(t *testing.T) { + expectedSql := ` +INSERT INTO test_sample.link (url, name) VALUES + ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +` + + tutorial := model.Link{ + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", + } + + google := model.Link{ + URL: "http://www.google.com", + Name: "Google", + } + + yahoo := model.Link{ + URL: "http://www.yahoo.com", + Name: "Yahoo", + } + + stmt := Link. + INSERT(Link.URL, Link.Name). + MODELS([]model.Link{tutorial, google, yahoo}) + + assertStatementSql(t, stmt, expectedSql, + "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + "http://www.google.com", "Google", + "http://www.yahoo.com", "Yahoo") + + assertExec(t, stmt, 3) +} + +func TestInsertUsingMutableColumns(t *testing.T) { + var expectedSql = ` +INSERT INTO test_sample.link (url, name, description) VALUES + ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +` + + google := model.Link{ + URL: "http://www.google.com", + Name: "Google", + } + + yahoo := model.Link{ + URL: "http://www.yahoo.com", + Name: "Yahoo", + } + + stmt := Link. + INSERT(Link.MutableColumns). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + MODEL(google). + MODELS([]model.Link{google, yahoo}) + + assertStatementSql(t, stmt, expectedSql, + "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + "http://www.google.com", "Google", nil, + "http://www.google.com", "Google", nil, + "http://www.yahoo.com", "Yahoo", nil) + + assertExec(t, stmt, 4) +} + func TestInsertQuery(t *testing.T) { _, err := Link.DELETE(). WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). @@ -119,8 +199,7 @@ INSERT INTO test_sample.link (url, name) ( RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", - link.description AS "link.description", - link.rel AS "link.rel"; + link.description AS "link.description"; ` query := Link. diff --git a/tests/update_test.go b/tests/update_test.go index 4bdffbd..32738d3 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -86,8 +86,7 @@ WHERE link.name = 'Ask' RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", - link.description AS "link.description", - link.rel AS "link.rel"; + link.description AS "link.description"; ` stmt := Link. @@ -120,12 +119,11 @@ func TestUpdateWithSelect(t *testing.T) { expectedSql := ` UPDATE test_sample.link -SET (id, url, name, description, rel) = ( +SET (id, url, name, description) = ( SELECT link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", - link.description AS "link.description", - link.rel AS "link.rel" + link.description AS "link.description" FROM test_sample.link WHERE link.id = 0 ) @@ -148,7 +146,7 @@ func TestUpdateWithInvalidSelect(t *testing.T) { var expectedSql = ` UPDATE test_sample.link -SET (id, url, name, description, rel) = ( +SET (id, url, name, description) = ( SELECT link.id AS "link.id", link.name AS "link.name" FROM test_sample.link @@ -177,10 +175,10 @@ func TestUpdateWithModelData(t *testing.T) { expectedSql := ` UPDATE test_sample.link -SET (id, url, name, description, rel) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL) +SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) WHERE link.id = 201; ` - assertStatementSql(t, stmt, expectedSql, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) + assertStatementSql(t, stmt, expectedSql, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) assertExec(t, stmt, 1) } @@ -195,7 +193,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { Name: "DuckDuckGo", } - updateColumnList := ColumnList{Link.Rel, Link.Name, Link.URL} + updateColumnList := ColumnList{Link.Description, Link.Name, Link.URL} stmt := Link. UPDATE(updateColumnList). @@ -204,7 +202,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSql = ` UPDATE test_sample.link -SET (rel, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') +SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') WHERE link.id = 201; ` assertStatementSql(t, stmt, expectedSql, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) @@ -252,7 +250,7 @@ func setupLinkTableForUpdateTest(t *testing.T) { cleanUpLinkTable(t) - _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Rel). + _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(201, "http://www.ask.com", "Ask", DEFAULT). VALUES(202, "http://www.ask.com", "Ask", DEFAULT). diff --git a/utils.go b/utils.go index f9c6f68..982c986 100644 --- a/utils.go +++ b/utils.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/serenize/snaker" "reflect" + "strings" ) func serializeOrderByClauseList(statement statementType, orderByClauses []OrderByClause, out *queryData) error { @@ -166,6 +167,21 @@ func unwindRowFromModel(columns []column, data interface{}) []clause { return row } +func unwindRowsFromModels(columns []column, data interface{}) [][]clause { + sliceValue := reflect.Indirect(reflect.ValueOf(data)) + mustBe(sliceValue, reflect.Slice) + + rows := [][]clause{} + + for i := 0; i < sliceValue.Len(); i++ { + structValue := sliceValue.Index(i) + + rows = append(rows, unwindRowFromModel(columns, structValue.Interface())) + } + + return rows +} + func unwindRowFromValues(value interface{}, values []interface{}) []clause { row := []clause{} @@ -178,8 +194,16 @@ func unwindRowFromValues(value interface{}, values []interface{}) []clause { return row } -func mustBe(v reflect.Value, expected reflect.Kind) { - if k := v.Kind(); k != expected { - panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String()) +func mustBe(v reflect.Value, expectedKinds ...reflect.Kind) { + indirectV := reflect.Indirect(v) + types := []string{} + + for _, expectedKind := range expectedKinds { + types = append(types, expectedKind.String()) + if k := indirectV.Kind(); k == expectedKind { + return + } } + + panic("argument mismatch: expected " + strings.Join(types, " or ") + ", got " + v.Type().String()) }