From 8a2c34fbd7ef70d585f2971d4888618a529152e9 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 15 Jun 2019 13:58:45 +0200 Subject: [PATCH] Select lock and table lock improvements. --- sqlbuilder/bool_expression_test.go | 1 + sqlbuilder/column_test_old.go | 210 ------------- sqlbuilder/delete_statement.go | 6 +- sqlbuilder/expression.go | 6 +- sqlbuilder/expression_old_test.go | 438 ---------------------------- sqlbuilder/expression_test.go | 2 +- sqlbuilder/insert_statement.go | 11 - sqlbuilder/insert_statement_test.go | 21 +- sqlbuilder/lock_statement.go | 8 +- sqlbuilder/lock_statement_test.go | 40 +-- sqlbuilder/row_type.go | 10 + sqlbuilder/select_statement.go | 78 ++++- sqlbuilder/select_statement_test.go | 108 +++++++ sqlbuilder/table.go | 6 +- sqlbuilder/test_utils.go | 3 + sqlbuilder/types.go | 35 --- sqlbuilder/utils.go | 28 +- tests/select_test.go | 109 ++++++- tests/update_test.go | 5 + 19 files changed, 363 insertions(+), 762 deletions(-) delete mode 100644 sqlbuilder/column_test_old.go create mode 100644 sqlbuilder/row_type.go create mode 100644 sqlbuilder/select_statement_test.go delete mode 100644 sqlbuilder/types.go diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index 19e1d3c..0191c73 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -5,6 +5,7 @@ import ( ) func TestBoolExpressionEQ(t *testing.T) { + assertClauseSerializeErr(t, table1ColBool.EQ(nil), "nil rhs") assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.colBool = table2.colBool)") assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.colBool = $1)", true) } diff --git a/sqlbuilder/column_test_old.go b/sqlbuilder/column_test_old.go deleted file mode 100644 index 5f4e47b..0000000 --- a/sqlbuilder/column_test_old.go +++ /dev/null @@ -1,210 +0,0 @@ -// +build disabled - -package sqlbuilder - -import ( - "bytes" - "testing" - - gc "gopkg.in/check.v1" -) - -func Test(t *testing.T) { - gc.TestingT(t) -} - -type ColumnSuite struct { -} - -var _ = gc.Suite(&ColumnSuite{}) - -// -// tests for columnImpl and columns that extends columnImpl -// - -func (s *ColumnSuite) TestRealColumnName(c *gc.C) { - col := IntColumn("col", Nullable) - - c.Assert(col.Name(), gc.Equals, "col") -} - -func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) { - col := IntColumn("col", Nullable) - - // Without tableName name - buf := &bytes.Buffer{} - - err := col.SerializeSqlForColumnList(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "col") - - // With tableName name - err = col.setTableName("foo") - c.Assert(err, gc.IsNil) - - buf = &bytes.Buffer{} - - err = col.SerializeSqlForColumnList(buf) - c.Assert(err, gc.IsNil) - - sql = buf.String() - c.Assert(sql, gc.Equals, "foo.col") -} - -func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) { - col := IntColumn("col", Nullable) - - // Without tableName name - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "col") - - // With tableName name - err = col.setTableName("foo") - c.Assert(err, gc.IsNil) - - buf = &bytes.Buffer{} - - err = col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql = buf.String() - c.Assert(sql, gc.Equals, "foo.col") -} - -// -// tests for AliasCoulmns -// - -func (s *ColumnSuite) TestAliasColumnName(c *gc.C) { - col := Alias("foo", SqlFunc("max", table1Col1)) - - c.Assert(col.Name(), gc.Equals, "foo") -} - -func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnList(c *gc.C) { - col := Alias("foo", SqlFunc("max", table1Col1)) - - buf := &bytes.Buffer{} - err := col.SerializeSqlForColumnList(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(err, gc.IsNil) - - c.Assert(sql, gc.Equals, "(max(table1.col1)) AS \"foo\"") -} - -func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnListNilExpr(c *gc.C) { - col := Alias("foo", nil) - - buf := &bytes.Buffer{} - err := col.SerializeSqlForColumnList(buf) - c.Assert(err, gc.NotNil) -} - -func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnListInvalidAlias( - c *gc.C) { - - col := Alias("1234", SqlFunc("max", table1Col1)) - - buf := &bytes.Buffer{} - err := col.SerializeSqlForColumnList(buf) - c.Assert(err, gc.NotNil) -} - -func (s *ColumnSuite) TestAliasColumnSerializeSql(c *gc.C) { - col := Alias("foo", SqlFunc("max", table1Col1)) - - buf := &bytes.Buffer{} - err := col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "`foo`") -} - -func (s *ColumnSuite) TestAliasColumnSetTableName(c *gc.C) { - col := Alias("foo", SqlFunc("max", table1Col1)) - - // should always error - err := col.setTableName("test") - c.Assert(err, gc.NotNil) -} - -// -// tests for deferredLookkupColumnName -// - -func (s *ColumnSuite) TestDeferredLookupColumnName(c *gc.C) { - col := table1.C("foo") - - c.Assert(col.Name(), gc.Equals, "foo") -} - -func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlForColumnList( - c *gc.C) { - - col := table1.C("col1") - - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1") - - // check cached lookup - buf = &bytes.Buffer{} - - err = col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql = buf.String() - c.Assert(sql, gc.Equals, "table1.col1") -} - -func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlForColumnListInvalidName( - c *gc.C) { - col := table1.C("foo") - - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.NotNil) -} - -func (s *ColumnSuite) TestDeferredLookupColumnSerializeSql(c *gc.C) { - col := table1.C("col1") - - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1") -} - -func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlInvalidName(c *gc.C) { - col := table1.C("foo") - - buf := &bytes.Buffer{} - - err := col.SerializeSql(buf) - c.Assert(err, gc.NotNil) -} - -func (s *ColumnSuite) TestDeferredLookupColumnSetTableName(c *gc.C) { - col := table1.C("col1") - - err := col.setTableName("foo") - c.Assert(err, gc.NotNil) -} diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index e8ed466..1c38783 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -30,13 +30,13 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { func (d *deleteStatementImpl) serializeImpl(out *queryData) error { if d == nil { - return errors.New("Delete expression. ") + return errors.New("delete statement is nil") } out.newLine() out.writeString("DELETE FROM") if d.table == nil { - return errors.New("nil tableName.") + return errors.New("nil tableName") } if err := d.table.serialize(delete_statement, out); err != nil { @@ -44,7 +44,7 @@ func (d *deleteStatementImpl) serializeImpl(out *queryData) error { } if d.where == nil { - return errors.New("Deleting without a WHERE clause.") + return errors.New("deleting without a WHERE clause") } if err := out.writeWhere(delete_statement, d.where); err != nil { diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 440dffe..2c836eb 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -163,13 +163,13 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio func (c *binaryOpExpression) serialize(statement statementType, out *queryData, options ...serializeOption) error { if c == nil { - return errors.New("Binary Expression is nil.") + return errors.New("binary Expression is nil") } if c.lhs == nil { - return errors.New("nil lhs.") + return errors.New("nil lhs") } if c.rhs == nil { - return errors.New("nil rhs.") + return errors.New("nil rhs") } wrap := !contains(options, NO_WRAP) diff --git a/sqlbuilder/expression_old_test.go b/sqlbuilder/expression_old_test.go index f4cdcc3..fc42841 100644 --- a/sqlbuilder/expression_old_test.go +++ b/sqlbuilder/expression_old_test.go @@ -9,444 +9,6 @@ import ( gc "gopkg.in/check.v1" ) -type ExprSuite struct { -} - -var _ = gc.Suite(&ExprSuite{}) - -func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) { - expr := And() - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.NotNil) -} - -func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) { - expr := And(nil, EqL(table1Col1, 1)) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.NotNil) -} - -func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) { - expr := And(EqL(table1Col1, 1)) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1=1") -} - -func (s *ExprSuite) TestLikeExpr(c *gc.C) { - expr := LikeL(table1Col1, EscapeForLike("%my_prefix")+"%") - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "table1.col1 LIKE '\\%my\\_prefix%'") - -} - -func (s *ExprSuite) TestRegexExpr(c *gc.C) { - expr := RegexpL(table1Col1, "[[:<:]]log|[[.low-line.]]log") - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "table1.col1 REGEXP '[[:<:]]log|[[.low-line.]]log'") - -} - -func (s *ExprSuite) TestAndExpr(c *gc.C) { - expr := And(EqL(table1Col1, 1), EqL(table1ColFloat, 2), EqL(table1Col3, 3)) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "(table1.col1=1 AND table1.col2=2 AND table1.col3=3)") -} - -func (s *ExprSuite) TestOrExpr(c *gc.C) { - expr := Or(EqL(table1Col1, 1), EqL(table1ColFloat, 2), EqL(table1Col3, 3)) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "(table1.col1=1 OR table1.col2=2 OR table1.col3=3)") -} - -func (s *ExprSuite) TestAddExpr(c *gc.C) { - expr := Add(literal(1), literal(2), literal(3)) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "(1 + 2 + 3)") -} - -func (s *ExprSuite) TestSubExpr(c *gc.C) { - expr := Sub(literal(1), literal(2), literal(3)) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "(1 - 2 - 3)") -} - -func (s *ExprSuite) TestMulExpr(c *gc.C) { - expr := Mul(literal(1), literal(2), literal(3)) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "(1 * 2 * 3)") -} - -func (s *ExprSuite) TestDivExpr(c *gc.C) { - expr := Div(literal(1), literal(2), literal(3)) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "(1 / 2 / 3)") -} - -func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) { - expr := GT(nil, table1Col1) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.NotNil) -} - -func (s *ExprSuite) TestNegateExpr(c *gc.C) { - expr := NOT(EqL(table1Col1, 123)) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "NOT (table1.col1=123)") -} - -func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) { - expr := LT(table1Col1, nil) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.NotNil) -} - -func (s *ExprSuite) TestEqExpr(c *gc.C) { - expr := EqL(table1Col1, 321) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1=321") -} - -func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) { - expr := EqL(table1Col1, nil) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1 IS null") -} - -func (s *ExprSuite) TestNeqExpr(c *gc.C) { - expr := NeqL(table1Col1, 123) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1!=123") -} - -func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) { - expr := NeqL(table1Col1, nil) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1 IS NOT null") -} - -func (s *ExprSuite) TestLtExpr(c *gc.C) { - expr := LtL(table1Col1, -1.5) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1<-1.5") -} - -func (s *ExprSuite) TestLteExpr(c *gc.C) { - expr := LteL(table1Col1, "foo\"';drop user tableName;") - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "table1.col1<='foo\\\"\\';drop user tableName;'") -} - -func (s *ExprSuite) TestGtExpr(c *gc.C) { - expr := GtL(table1Col1, 1.1) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1>1.1") -} - -func (s *ExprSuite) TestGteExpr(c *gc.C) { - expr := GteL(table1Col1, 1) - - buf := &bytes.Buffer{} - - err := expr.serialize(0, buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1>=1") -} - -func (s *ExprSuite) TestInExpr(c *gc.C) { - values := []int32{1, 2, 3} - expr := In(table1Col1, values) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1 IN (1,2,3)") -} - -func (s *ExprSuite) TestInExprEmptyList(c *gc.C) { - values := []int32{} - expr := In(table1Col1, values) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "FALSE") -} - -func (s *ExprSuite) TestSqlFuncExprNilInArgList(c *gc.C) { - expr := SqlFunc("rand", nil) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.NotNil) -} - -func (s *ExprSuite) TestSqlFuncExprEmptyArgList(c *gc.C) { - expr := SqlFunc("rand") - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "rand()") -} - -func (s *ExprSuite) TestSqlFuncExprNonEmptyArgList(c *gc.C) { - expr := SqlFunc("add", table1Col1, table1ColFloat) - - buf := &bytes.Buffer{} - - err := expr.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "add(table1.col1,table1.col2)") -} - -func (s *ExprSuite) TestOrderByClauseNilExpr(c *gc.C) { - clause := ASC(nil) - - buf := &bytes.Buffer{} - - err := clause.serialize(buf) - c.Assert(err, gc.NotNil) -} - -func (s *ExprSuite) TestAsc(c *gc.C) { - clause := ASC(table1Col1) - - buf := &bytes.Buffer{} - - err := clause.serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1 ASC") -} - -func (s *ExprSuite) TestDesc(c *gc.C) { - clause := DESC(table1Col1) - - buf := &bytes.Buffer{} - - err := clause.serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "table1.col1 DESC") -} - -func (s *ExprSuite) TestColumnValue(c *gc.C) { - clause := ColumnValue(table1Col1) - - buf := &bytes.Buffer{} - - err := clause.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "VALUES(table1.col1)") -} - -func (s *ExprSuite) TestBitwiseOr(c *gc.C) { - clause := BitOr(literal(1), literal(2)) - - buf := &bytes.Buffer{} - - err := clause.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "1 | 2") -} - -func (s *ExprSuite) TestBitwiseAnd(c *gc.C) { - clause := BitAnd(literal(1), literal(2)) - - buf := &bytes.Buffer{} - - err := clause.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "1 & 2") -} - -func (s *ExprSuite) TestBitwiseXor(c *gc.C) { - clause := BitXor(literal(1), literal(2)) - - buf := &bytes.Buffer{} - - err := clause.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "1 ^ 2") -} - -func (s *ExprSuite) TestPlus(c *gc.C) { - clause := Plus(literal(1), literal(2)) - - buf := &bytes.Buffer{} - - err := clause.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "1 + 2") -} - -func (s *ExprSuite) TestMinus(c *gc.C) { - clause := Minus(literal(1), literal(2)) - - buf := &bytes.Buffer{} - - err := clause.SerializeSql(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert(sql, gc.Equals, "1 - 2") -} - func (s *ExprSuite) TestInterval(c *gc.C) { testTable := []struct { interval time.Duration diff --git a/sqlbuilder/expression_test.go b/sqlbuilder/expression_test.go index 1593f57..c7e593e 100644 --- a/sqlbuilder/expression_test.go +++ b/sqlbuilder/expression_test.go @@ -7,7 +7,7 @@ import ( func TestExpressionIS_NULL(t *testing.T) { assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") - assertClauseSerializeErr(t, table2Col3.ADD(nil), "nil rhs.") + assertClauseSerializeErr(t, table2Col3.ADD(nil), "nil rhs") } func TestExpressionIS_NOT_NULL(t *testing.T) { diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index e7e8089..0ca67d7 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -4,7 +4,6 @@ import ( "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" - "strings" ) type InsertStatement interface { @@ -33,8 +32,6 @@ type insertStatementImpl struct { rows [][]clause query SelectStatement returning []projection - - errors []string } func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { @@ -57,19 +54,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState return i } -func (i *insertStatementImpl) addError(err string) { - i.errors = append(i.errors, err) -} - func (i *insertStatementImpl) DebugSql() (query string, err error) { return DebugSql(i) } func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { - if len(i.errors) > 0 { - return "", nil, errors.New("errors: " + strings.Join(i.errors, ", ")) - } - queryData := &queryData{} queryData.newLine() diff --git a/sqlbuilder/insert_statement_test.go b/sqlbuilder/insert_statement_test.go index f674ae3..0c0cab5 100644 --- a/sqlbuilder/insert_statement_test.go +++ b/sqlbuilder/insert_statement_test.go @@ -119,24 +119,39 @@ INSERT INTO db.table1 (col1, colFloat) VALUES } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, r, "missing struct field for column : col1") + }() type Table1Model struct { Col1Prim int Col2 string } - toInsert := Table1Model{ + newData := Table1Model{ Col1Prim: 1, Col2: "one", } - stmt := table1.INSERT(table1Col1, table1ColFloat). - USING(toInsert) + stmt := table1. + INSERT(table1Col1, table1ColFloat). + USING(newData) _, _, err := stmt.Sql() assert.Assert(t, err != nil) } +func TestInsertFromNonStructModel(t *testing.T) { + + defer func() { + r := recover() + assert.Equal(t, r, "argument mismatch: expected struct, got []int") + }() + + table2.INSERT(table2ColInt).USING([]int{}) +} + func TestInsertQuery(t *testing.T) { stmt := table1.INSERT(table1Col1). diff --git a/sqlbuilder/lock_statement.go b/sqlbuilder/lock_statement.go index b9271bb..4d960ad 100644 --- a/sqlbuilder/lock_statement.go +++ b/sqlbuilder/lock_statement.go @@ -6,7 +6,7 @@ import ( "github.com/pkg/errors" ) -type lockMode string +type TableLockMode string const ( LOCK_ACCESS_SHARE = "ACCESS SHARE" @@ -22,13 +22,13 @@ const ( type LockStatement interface { Statement - IN(lockMode lockMode) LockStatement + IN(lockMode TableLockMode) LockStatement NOWAIT() LockStatement } type lockStatementImpl struct { tables []WritableTable - lockMode lockMode + lockMode TableLockMode nowait bool } @@ -38,7 +38,7 @@ func LOCK(tables ...WritableTable) LockStatement { } } -func (l *lockStatementImpl) IN(lockMode lockMode) LockStatement { +func (l *lockStatementImpl) IN(lockMode TableLockMode) LockStatement { l.lockMode = lockMode return l } diff --git a/sqlbuilder/lock_statement_test.go b/sqlbuilder/lock_statement_test.go index af25cea..da878e1 100644 --- a/sqlbuilder/lock_statement_test.go +++ b/sqlbuilder/lock_statement_test.go @@ -1,28 +1,32 @@ package sqlbuilder import ( - "gotest.tools/assert" "testing" ) -func TestLockSingleTable(t *testing.T) { - lock := table1.LOCK().IN(LOCK_ROW_SHARE) - - queryStr, _, err := lock.Sql() - - assert.NilError(t, err) - assert.Equal(t, queryStr, ` +func TestLockTable(t *testing.T) { + assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_SHARE), ` +LOCK TABLE db.table1 IN ACCESS SHARE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_ROW_SHARE), ` LOCK TABLE db.table1 IN ROW SHARE MODE; `) -} - -func TestLockMultipleTable(t *testing.T) { - lock := LOCK(table2, table1).IN(LOCK_ACCESS_EXCLUSIVE).NOWAIT() - - queryStr, _, err := lock.Sql() - - assert.NilError(t, err) - assert.Equal(t, queryStr, ` -LOCK TABLE db.table2, db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT; + assertStatement(t, table1.LOCK().IN(LOCK_ROW_EXCLUSIVE), ` +LOCK TABLE db.table1 IN ROW EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_SHARE_UPDATE_EXCLUSIVE), ` +LOCK TABLE db.table1 IN SHARE UPDATE EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_SHARE), ` +LOCK TABLE db.table1 IN SHARE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_SHARE_ROW_EXCLUSIVE), ` +LOCK TABLE db.table1 IN SHARE ROW EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_EXCLUSIVE), ` +LOCK TABLE db.table1 IN EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_EXCLUSIVE), ` +LOCK TABLE db.table1 IN ACCESS EXCLUSIVE MODE; `) } diff --git a/sqlbuilder/row_type.go b/sqlbuilder/row_type.go new file mode 100644 index 0000000..b89914d --- /dev/null +++ b/sqlbuilder/row_type.go @@ -0,0 +1,10 @@ +package sqlbuilder + +type rowsType interface { + clause + hasRows() +} + +type isRowsType struct{} + +func (i *isRowsType) hasRows() {} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 3c4a7d9..ebf8fa1 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -6,6 +6,13 @@ import ( "github.com/go-jet/jet/sqlbuilder/execution" ) +var ( + UPDATE = newLock("UPDATE") + NO_KEY_UPDATE = newLock("NO KEY UPDATE") + SHARE = newLock("SHARE") + KEY_SHARE = newLock("KEY SHARE") +) + type SelectStatement interface { Statement Expression @@ -21,17 +28,15 @@ type SelectStatement interface { LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement - FOR_UPDATE() SelectStatement + FOR(lock SelectLock) SelectStatement AsTable(alias string) ExpressionTable } -func SELECT(projection ...projection) SelectStatement { - return newSelectStatement(nil, projection) +func SELECT(projection1 projection, projections ...projection) SelectStatement { + return newSelectStatement(nil, append([]projection{projection1}, projections...)) } -// NOTE: SelectStatement purposely does not implement the tableImpl interface since -// mysql's subquery performance is horrible. type selectStatementImpl struct { expressionInterfaceImpl isRowsType @@ -46,7 +51,7 @@ type selectStatementImpl struct { limit, offset int64 - forUpdate bool + lockFor SelectLock } func newSelectStatement(table ReadableTable, projections []projection) SelectStatement { @@ -55,7 +60,6 @@ func newSelectStatement(table ReadableTable, projections []projection) SelectSta projections: projections, limit: -1, offset: -1, - forUpdate: false, distinct: false, } @@ -161,15 +165,19 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { out.insertPreparedArgument(s.offset) } - if s.forUpdate { + if s.lockFor != nil { out.newLine() - out.writeString("FOR UPDATE") + out.writeString("FOR") + err := s.lockFor.serialize(select_statement, out) + + if err != nil { + return err + } } return nil } -// Return the properly escaped SQL Statement, against the specified database func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) { queryData := queryData{} @@ -229,11 +237,57 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { return s } -func (s *selectStatementImpl) FOR_UPDATE() SelectStatement { - s.forUpdate = true +func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { + s.lockFor = lock return s } +type SelectLock interface { + clause + + NOWAIT() SelectLock + SKIP_LOCKED() SelectLock +} + +type selectLockImpl struct { + lockStrength string + noWait, skipLocked bool +} + +func newLock(name string) func() SelectLock { + return func() SelectLock { + return newSelectLock(name) + } +} + +func newSelectLock(lockStrength string) SelectLock { + return &selectLockImpl{lockStrength: lockStrength} +} + +func (s *selectLockImpl) NOWAIT() SelectLock { + s.noWait = true + return s +} + +func (s *selectLockImpl) SKIP_LOCKED() SelectLock { + s.skipLocked = true + return s +} + +func (s *selectLockImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { + out.writeString(s.lockStrength) + + if s.noWait { + out.writeString("NOWAIT") + } + + if s.skipLocked { + out.writeString("SKIP LOCKED") + } + + return nil +} + func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) error { return Query(s, db, destination) } diff --git a/sqlbuilder/select_statement_test.go b/sqlbuilder/select_statement_test.go new file mode 100644 index 0000000..fcad208 --- /dev/null +++ b/sqlbuilder/select_statement_test.go @@ -0,0 +1,108 @@ +package sqlbuilder + +import "testing" + +func TestInvalidSelect(t *testing.T) { + assertStatementErr(t, SELECT(nil), "projection is nil") +} + +func TestSelectDistinct(t *testing.T) { + assertStatement(t, SELECT(table1ColBool).DISTINCT(), ` +SELECT DISTINCT table1.colBool AS "table1.colBool"; +`) +} + +func TestSelectFrom(t *testing.T) { + assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` +SELECT table1.colInt AS "table1.colInt", + table2.colFloat AS "table2.colFloat" +FROM db.table1; +`) + assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` +SELECT table1.colInt AS "table1.colInt", + table2.colFloat AS "table2.colFloat" +FROM db.table1 + INNER JOIN db.table2 ON (table1.colInt = table2.colInt); +`) + assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` +SELECT table1.colInt AS "table1.colInt", + table2.colFloat AS "table2.colFloat" +FROM db.table1 + INNER JOIN db.table2 ON (table1.colInt = table2.colInt); +`) +} + +func TestSelectWhere(t *testing.T) { + assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` +SELECT table1.colInt AS "table1.colInt" +FROM db.table1 +WHERE $1; +`, true) + assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` +SELECT table1.colInt AS "table1.colInt" +FROM db.table1 +WHERE table1.colInt >= $1; +`, int64(10)) +} + +func TestSelectGroupBy(t *testing.T) { + assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` +SELECT table2.colInt AS "table2.colInt" +FROM db.table2 +GROUP BY table2.colFloat; +`) +} + +func TestSelectHaving(t *testing.T) { + assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` +SELECT table3.colInt AS "table3.colInt" +FROM db.table3 +HAVING table1.colBool = $1; +`, true) +} + +func TestSelectOrderBy(t *testing.T) { + assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` +SELECT table2.colFloat AS "table2.colFloat" +FROM db.table2 +ORDER BY table2.colInt DESC; +`) +} + +func TestSelectLimitOffset(t *testing.T) { + assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` +SELECT table2.colInt AS "table2.colInt" +FROM db.table2 +LIMIT $1; +`, int64(10)) + assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` +SELECT table2.colInt AS "table2.colInt" +FROM db.table2 +LIMIT $1 +OFFSET $2; +`, int64(10), int64(2)) +} + +func TestSelectLock(t *testing.T) { + assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` +SELECT table1.colBool AS "table1.colBool" +FROM db.table1 +FOR UPDATE; +`) + assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` +SELECT table1.colBool AS "table1.colBool" +FROM db.table1 +FOR SHARE NOWAIT; +`) + + assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(KEY_SHARE().NOWAIT()), ` +SELECT table1.colBool AS "table1.colBool" +FROM db.table1 +FOR KEY SHARE NOWAIT; +`) + assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(NO_KEY_UPDATE().SKIP_LOCKED()), ` +SELECT table1.colBool AS "table1.colBool" +FROM db.table1 +FOR NO KEY UPDATE SKIP LOCKED; +`) +} diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index f889a32..3093722 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -8,7 +8,7 @@ import ( type readableTable interface { // Generates a select query on the current tableName. - SELECT(projections ...projection) SelectStatement + SELECT(projection projection, projections ...projection) SelectStatement // Creates a inner join tableName Expression using onCondition. INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable @@ -59,8 +59,8 @@ type readableTableInterfaceImpl struct { } // Generates a select query on the current tableName. -func (r *readableTableInterfaceImpl) SELECT(projections ...projection) SelectStatement { - return newSelectStatement(r.parent, projections) +func (r *readableTableInterfaceImpl) SELECT(projection1 projection, projections ...projection) SelectStatement { + return newSelectStatement(r.parent, append([]projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go index 845d884..37c1f47 100644 --- a/sqlbuilder/test_utils.go +++ b/sqlbuilder/test_utils.go @@ -65,6 +65,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { out := queryData{} err := clause.serialize(select_statement, &out) + //fmt.Println(out.buff.String()) assert.Assert(t, err != nil) assert.Equal(t, err.Error(), errString) } @@ -82,6 +83,8 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { queryStr, args, err := query.Sql() assert.NilError(t, err) + + //fmt.Println(queryStr) assert.Equal(t, queryStr, expectedQuery) assert.DeepEqual(t, args, expectedArgs) } diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go deleted file mode 100644 index 5d2f136..0000000 --- a/sqlbuilder/types.go +++ /dev/null @@ -1,35 +0,0 @@ -package sqlbuilder - -type rowsType interface { - clause - hasRows() -} - -type isRowsType struct{} - -func (i *isRowsType) hasRows() {} - -// A clause that can be used in orderBy by - -// A clause that is selectable. -//type projection interface { -// clause -// isProjectionInterface -// -// SerializeSqlForColumnList(out *bytes.Buffer) error -//} - -// -// Boiler plates ... -// - -// -//type isProjectionInterface interface { -// isProjectionType() -//} -// -//type isProjection struct { -//} -// -//func (p *isProjection) isProjectionType() { -//} diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 68aeab8..25c4929 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -88,7 +88,7 @@ func serializeProjectionList(statement statementType, projections []projection, } if col == nil { - return errors.New("projection Expression is nil") + return errors.New("projection is nil") } if err := col.serializeForProjection(statement, out); err != nil { @@ -115,6 +115,16 @@ func serializeColumnNames(columns []column, out *queryData) error { return nil } +func columnListToProjectionList(columns []Column) []projection { + var ret []projection + + for _, column := range columns { + ret = append(ret, column) + } + + return ret +} + func isNil(v interface{}) bool { return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) } @@ -132,9 +142,7 @@ func unwindRowFromModel(columns []column, data interface{}) []clause { row := []clause{} - if structValue.Kind() != reflect.Struct { - return row - } + mustBe(structValue, reflect.Struct) for _, column := range columns { columnName := column.Name() @@ -143,7 +151,7 @@ func unwindRowFromModel(columns []column, data interface{}) []clause { structField := structValue.FieldByName(structFieldName) if !structField.IsValid() { - continue + panic("missing struct field for column : " + column.Name()) } var field interface{} @@ -172,14 +180,10 @@ func unwindRowFromValues(value interface{}, values []interface{}) []clause { return row } -func columnListToProjectionList(columns []Column) []projection { - var ret []projection - - for _, column := range columns { - ret = append(ret, column) +func mustBe(v reflect.Value, expected reflect.Kind) { + if k := v.Kind(); k != expected { + panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String()) } - - return ret } func Query(statement Statement, db execution.Db, destination interface{}) error { diff --git a/tests/select_test.go b/tests/select_test.go index fa9490f..2043d1b 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -1104,18 +1104,109 @@ LIMIT 20; func TestLockTable(t *testing.T) { expectedSql := ` -LOCK TABLE dvds.address IN EXCLUSIVE MODE NOWAIT; -` - query := Address.LOCK().IN(LOCK_EXCLUSIVE).NOWAIT() +LOCK TABLE dvds.address IN` - querySql, _, _ := query.Sql() - fmt.Println("-" + querySql + "-") + var testData = []TableLockMode{ + LOCK_ACCESS_SHARE, + LOCK_ROW_SHARE, + LOCK_ROW_EXCLUSIVE, + LOCK_SHARE_UPDATE_EXCLUSIVE, + LOCK_SHARE, + LOCK_SHARE_ROW_EXCLUSIVE, + LOCK_EXCLUSIVE, + LOCK_ACCESS_EXCLUSIVE, + } - assertStatementSql(t, query, expectedSql) + for _, lockMode := range testData { + query := Address.LOCK().IN(lockMode) - tx, _ := db.Begin() + assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE;\n") - _, err := query.Exec(tx) + tx, _ := db.Begin() - assert.NilError(t, err) + _, err := query.Exec(tx) + + assert.NilError(t, err) + + tx.Rollback() + } + + for _, lockMode := range testData { + query := Address.LOCK().IN(lockMode).NOWAIT() + + assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE NOWAIT;\n") + + tx, _ := db.Begin() + + _, err := query.Exec(tx) + + assert.NilError(t, err) + + tx.Rollback() + } +} + +func getRowLockTestData() map[SelectLock]string { + return map[SelectLock]string{ + UPDATE(): "UPDATE", + NO_KEY_UPDATE(): "NO KEY UPDATE", + SHARE(): "SHARE", + KEY_SHARE(): "KEY SHARE", + } +} + +func TestRowLock(t *testing.T) { + expectedSql := ` +SELECT * +FROM dvds.address +LIMIT 3 +FOR` + query := Address. + SELECT(STAR). + LIMIT(3) + + for lockType, lockTypeStr := range getRowLockTestData() { + query.FOR(lockType) + + assertStatementSql(t, query, expectedSql+" "+lockTypeStr+";\n", int64(3)) + + tx, _ := db.Begin() + + res, err := query.Exec(tx) + assert.NilError(t, err) + rowsAffected, _ := res.RowsAffected() + assert.Equal(t, rowsAffected, int64(3)) + + tx.Rollback() + } + + for lockType, lockTypeStr := range getRowLockTestData() { + query.FOR(lockType.NOWAIT()) + + assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" NOWAIT;\n", int64(3)) + + tx, _ := db.Begin() + + res, err := query.Exec(tx) + assert.NilError(t, err) + rowsAffected, _ := res.RowsAffected() + assert.Equal(t, rowsAffected, int64(3)) + + tx.Rollback() + } + + for lockType, lockTypeStr := range getRowLockTestData() { + query.FOR(lockType.SKIP_LOCKED()) + + assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3)) + + tx, _ := db.Begin() + + res, err := query.Exec(tx) + assert.NilError(t, err) + rowsAffected, _ := res.RowsAffected() + assert.Equal(t, rowsAffected, int64(3)) + + tx.Rollback() + } } diff --git a/tests/update_test.go b/tests/update_test.go index 26997f1..866a118 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -213,6 +213,11 @@ WHERE link.id = 201; } func TestUpdateWithInvalidModelData(t *testing.T) { + defer func() { + r := recover() + + assert.Equal(t, r, "missing struct field for column : id") + }() setupLinkTableForUpdateTest(t)