Select lock and table lock improvements.

This commit is contained in:
go-jet 2019-06-15 13:58:45 +02:00
parent a4feb66692
commit 8a2c34fbd7
19 changed files with 363 additions and 762 deletions

View file

@ -5,6 +5,7 @@ import (
) )
func TestBoolExpressionEQ(t *testing.T) { func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "nil rhs")
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.colBool = table2.colBool)") assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.colBool = table2.colBool)")
assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.colBool = $1)", true) assertClauseSerialize(t, table1ColBool.EQ(Bool(true)), "(table1.colBool = $1)", true)
} }

View file

@ -1,210 +0,0 @@
// +build disabled
package sqlbuilder
import (
"bytes"
"testing"
gc "gopkg.in/check.v1"
)
func Test(t *testing.T) {
gc.TestingT(t)
}
type ColumnSuite struct {
}
var _ = gc.Suite(&ColumnSuite{})
//
// tests for columnImpl and columns that extends columnImpl
//
func (s *ColumnSuite) TestRealColumnName(c *gc.C) {
col := IntColumn("col", Nullable)
c.Assert(col.Name(), gc.Equals, "col")
}
func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) {
col := IntColumn("col", Nullable)
// Without tableName name
buf := &bytes.Buffer{}
err := col.SerializeSqlForColumnList(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "col")
// With tableName name
err = col.setTableName("foo")
c.Assert(err, gc.IsNil)
buf = &bytes.Buffer{}
err = col.SerializeSqlForColumnList(buf)
c.Assert(err, gc.IsNil)
sql = buf.String()
c.Assert(sql, gc.Equals, "foo.col")
}
func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) {
col := IntColumn("col", Nullable)
// Without tableName name
buf := &bytes.Buffer{}
err := col.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "col")
// With tableName name
err = col.setTableName("foo")
c.Assert(err, gc.IsNil)
buf = &bytes.Buffer{}
err = col.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql = buf.String()
c.Assert(sql, gc.Equals, "foo.col")
}
//
// tests for AliasCoulmns
//
func (s *ColumnSuite) TestAliasColumnName(c *gc.C) {
col := Alias("foo", SqlFunc("max", table1Col1))
c.Assert(col.Name(), gc.Equals, "foo")
}
func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnList(c *gc.C) {
col := Alias("foo", SqlFunc("max", table1Col1))
buf := &bytes.Buffer{}
err := col.SerializeSqlForColumnList(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(err, gc.IsNil)
c.Assert(sql, gc.Equals, "(max(table1.col1)) AS \"foo\"")
}
func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnListNilExpr(c *gc.C) {
col := Alias("foo", nil)
buf := &bytes.Buffer{}
err := col.SerializeSqlForColumnList(buf)
c.Assert(err, gc.NotNil)
}
func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnListInvalidAlias(
c *gc.C) {
col := Alias("1234", SqlFunc("max", table1Col1))
buf := &bytes.Buffer{}
err := col.SerializeSqlForColumnList(buf)
c.Assert(err, gc.NotNil)
}
func (s *ColumnSuite) TestAliasColumnSerializeSql(c *gc.C) {
col := Alias("foo", SqlFunc("max", table1Col1))
buf := &bytes.Buffer{}
err := col.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "`foo`")
}
func (s *ColumnSuite) TestAliasColumnSetTableName(c *gc.C) {
col := Alias("foo", SqlFunc("max", table1Col1))
// should always error
err := col.setTableName("test")
c.Assert(err, gc.NotNil)
}
//
// tests for deferredLookkupColumnName
//
func (s *ColumnSuite) TestDeferredLookupColumnName(c *gc.C) {
col := table1.C("foo")
c.Assert(col.Name(), gc.Equals, "foo")
}
func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlForColumnList(
c *gc.C) {
col := table1.C("col1")
buf := &bytes.Buffer{}
err := col.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1")
// check cached lookup
buf = &bytes.Buffer{}
err = col.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql = buf.String()
c.Assert(sql, gc.Equals, "table1.col1")
}
func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlForColumnListInvalidName(
c *gc.C) {
col := table1.C("foo")
buf := &bytes.Buffer{}
err := col.SerializeSql(buf)
c.Assert(err, gc.NotNil)
}
func (s *ColumnSuite) TestDeferredLookupColumnSerializeSql(c *gc.C) {
col := table1.C("col1")
buf := &bytes.Buffer{}
err := col.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1")
}
func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlInvalidName(c *gc.C) {
col := table1.C("foo")
buf := &bytes.Buffer{}
err := col.SerializeSql(buf)
c.Assert(err, gc.NotNil)
}
func (s *ColumnSuite) TestDeferredLookupColumnSetTableName(c *gc.C) {
col := table1.C("col1")
err := col.setTableName("foo")
c.Assert(err, gc.NotNil)
}

View file

@ -30,13 +30,13 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
func (d *deleteStatementImpl) serializeImpl(out *queryData) error { func (d *deleteStatementImpl) serializeImpl(out *queryData) error {
if d == nil { if d == nil {
return errors.New("Delete expression. ") return errors.New("delete statement is nil")
} }
out.newLine() out.newLine()
out.writeString("DELETE FROM") out.writeString("DELETE FROM")
if d.table == nil { if d.table == nil {
return errors.New("nil tableName.") return errors.New("nil tableName")
} }
if err := d.table.serialize(delete_statement, out); err != nil { if err := d.table.serialize(delete_statement, out); err != nil {
@ -44,7 +44,7 @@ func (d *deleteStatementImpl) serializeImpl(out *queryData) error {
} }
if d.where == nil { if d.where == nil {
return errors.New("Deleting without a WHERE clause.") return errors.New("deleting without a WHERE clause")
} }
if err := out.writeWhere(delete_statement, d.where); err != nil { if err := out.writeWhere(delete_statement, d.where); err != nil {

View file

@ -163,13 +163,13 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio
func (c *binaryOpExpression) serialize(statement statementType, out *queryData, options ...serializeOption) error { func (c *binaryOpExpression) serialize(statement statementType, out *queryData, options ...serializeOption) error {
if c == nil { if c == nil {
return errors.New("Binary Expression is nil.") return errors.New("binary Expression is nil")
} }
if c.lhs == nil { if c.lhs == nil {
return errors.New("nil lhs.") return errors.New("nil lhs")
} }
if c.rhs == nil { if c.rhs == nil {
return errors.New("nil rhs.") return errors.New("nil rhs")
} }
wrap := !contains(options, NO_WRAP) wrap := !contains(options, NO_WRAP)

View file

@ -9,444 +9,6 @@ import (
gc "gopkg.in/check.v1" gc "gopkg.in/check.v1"
) )
type ExprSuite struct {
}
var _ = gc.Suite(&ExprSuite{})
func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) {
expr := And()
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) {
expr := And(nil, EqL(table1Col1, 1))
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) {
expr := And(EqL(table1Col1, 1))
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1=1")
}
func (s *ExprSuite) TestLikeExpr(c *gc.C) {
expr := LikeL(table1Col1, EscapeForLike("%my_prefix")+"%")
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"table1.col1 LIKE '\\%my\\_prefix%'")
}
func (s *ExprSuite) TestRegexExpr(c *gc.C) {
expr := RegexpL(table1Col1, "[[:<:]]log|[[.low-line.]]log")
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"table1.col1 REGEXP '[[:<:]]log|[[.low-line.]]log'")
}
func (s *ExprSuite) TestAndExpr(c *gc.C) {
expr := And(EqL(table1Col1, 1), EqL(table1ColFloat, 2), EqL(table1Col3, 3))
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"(table1.col1=1 AND table1.col2=2 AND table1.col3=3)")
}
func (s *ExprSuite) TestOrExpr(c *gc.C) {
expr := Or(EqL(table1Col1, 1), EqL(table1ColFloat, 2), EqL(table1Col3, 3))
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"(table1.col1=1 OR table1.col2=2 OR table1.col3=3)")
}
func (s *ExprSuite) TestAddExpr(c *gc.C) {
expr := Add(literal(1), literal(2), literal(3))
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "(1 + 2 + 3)")
}
func (s *ExprSuite) TestSubExpr(c *gc.C) {
expr := Sub(literal(1), literal(2), literal(3))
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "(1 - 2 - 3)")
}
func (s *ExprSuite) TestMulExpr(c *gc.C) {
expr := Mul(literal(1), literal(2), literal(3))
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "(1 * 2 * 3)")
}
func (s *ExprSuite) TestDivExpr(c *gc.C) {
expr := Div(literal(1), literal(2), literal(3))
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "(1 / 2 / 3)")
}
func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) {
expr := GT(nil, table1Col1)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
func (s *ExprSuite) TestNegateExpr(c *gc.C) {
expr := NOT(EqL(table1Col1, 123))
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "NOT (table1.col1=123)")
}
func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) {
expr := LT(table1Col1, nil)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
func (s *ExprSuite) TestEqExpr(c *gc.C) {
expr := EqL(table1Col1, 321)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1=321")
}
func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) {
expr := EqL(table1Col1, nil)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1 IS null")
}
func (s *ExprSuite) TestNeqExpr(c *gc.C) {
expr := NeqL(table1Col1, 123)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1!=123")
}
func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) {
expr := NeqL(table1Col1, nil)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1 IS NOT null")
}
func (s *ExprSuite) TestLtExpr(c *gc.C) {
expr := LtL(table1Col1, -1.5)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1<-1.5")
}
func (s *ExprSuite) TestLteExpr(c *gc.C) {
expr := LteL(table1Col1, "foo\"';drop user tableName;")
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"table1.col1<='foo\\\"\\';drop user tableName;'")
}
func (s *ExprSuite) TestGtExpr(c *gc.C) {
expr := GtL(table1Col1, 1.1)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1>1.1")
}
func (s *ExprSuite) TestGteExpr(c *gc.C) {
expr := GteL(table1Col1, 1)
buf := &bytes.Buffer{}
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1>=1")
}
func (s *ExprSuite) TestInExpr(c *gc.C) {
values := []int32{1, 2, 3}
expr := In(table1Col1, values)
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1 IN (1,2,3)")
}
func (s *ExprSuite) TestInExprEmptyList(c *gc.C) {
values := []int32{}
expr := In(table1Col1, values)
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "FALSE")
}
func (s *ExprSuite) TestSqlFuncExprNilInArgList(c *gc.C) {
expr := SqlFunc("rand", nil)
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
}
func (s *ExprSuite) TestSqlFuncExprEmptyArgList(c *gc.C) {
expr := SqlFunc("rand")
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "rand()")
}
func (s *ExprSuite) TestSqlFuncExprNonEmptyArgList(c *gc.C) {
expr := SqlFunc("add", table1Col1, table1ColFloat)
buf := &bytes.Buffer{}
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "add(table1.col1,table1.col2)")
}
func (s *ExprSuite) TestOrderByClauseNilExpr(c *gc.C) {
clause := ASC(nil)
buf := &bytes.Buffer{}
err := clause.serialize(buf)
c.Assert(err, gc.NotNil)
}
func (s *ExprSuite) TestAsc(c *gc.C) {
clause := ASC(table1Col1)
buf := &bytes.Buffer{}
err := clause.serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1 ASC")
}
func (s *ExprSuite) TestDesc(c *gc.C) {
clause := DESC(table1Col1)
buf := &bytes.Buffer{}
err := clause.serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "table1.col1 DESC")
}
func (s *ExprSuite) TestColumnValue(c *gc.C) {
clause := ColumnValue(table1Col1)
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "VALUES(table1.col1)")
}
func (s *ExprSuite) TestBitwiseOr(c *gc.C) {
clause := BitOr(literal(1), literal(2))
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "1 | 2")
}
func (s *ExprSuite) TestBitwiseAnd(c *gc.C) {
clause := BitAnd(literal(1), literal(2))
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "1 & 2")
}
func (s *ExprSuite) TestBitwiseXor(c *gc.C) {
clause := BitXor(literal(1), literal(2))
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "1 ^ 2")
}
func (s *ExprSuite) TestPlus(c *gc.C) {
clause := Plus(literal(1), literal(2))
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "1 + 2")
}
func (s *ExprSuite) TestMinus(c *gc.C) {
clause := Minus(literal(1), literal(2))
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(sql, gc.Equals, "1 - 2")
}
func (s *ExprSuite) TestInterval(c *gc.C) { func (s *ExprSuite) TestInterval(c *gc.C) {
testTable := []struct { testTable := []struct {
interval time.Duration interval time.Duration

View file

@ -7,7 +7,7 @@ import (
func TestExpressionIS_NULL(t *testing.T) { func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")
assertClauseSerializeErr(t, table2Col3.ADD(nil), "nil rhs.") assertClauseSerializeErr(t, table2Col3.ADD(nil), "nil rhs")
} }
func TestExpressionIS_NOT_NULL(t *testing.T) { func TestExpressionIS_NOT_NULL(t *testing.T) {

View file

@ -4,7 +4,6 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
"strings"
) )
type InsertStatement interface { type InsertStatement interface {
@ -33,8 +32,6 @@ type insertStatementImpl struct {
rows [][]clause rows [][]clause
query SelectStatement query SelectStatement
returning []projection returning []projection
errors []string
} }
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
@ -57,19 +54,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
return i return i
} }
func (i *insertStatementImpl) addError(err string) {
i.errors = append(i.errors, err)
}
func (i *insertStatementImpl) DebugSql() (query string, err error) { func (i *insertStatementImpl) DebugSql() (query string, err error) {
return DebugSql(i) return DebugSql(i)
} }
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
if len(i.errors) > 0 {
return "", nil, errors.New("errors: " + strings.Join(i.errors, ", "))
}
queryData := &queryData{} queryData := &queryData{}
queryData.newLine() queryData.newLine()

View file

@ -119,24 +119,39 @@ INSERT INTO db.table1 (col1, colFloat) VALUES
} }
func TestInsertValuesFromModelColumnMismatch(t *testing.T) { func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct { type Table1Model struct {
Col1Prim int Col1Prim int
Col2 string Col2 string
} }
toInsert := Table1Model{ newData := Table1Model{
Col1Prim: 1, Col1Prim: 1,
Col2: "one", Col2: "one",
} }
stmt := table1.INSERT(table1Col1, table1ColFloat). stmt := table1.
USING(toInsert) INSERT(table1Col1, table1ColFloat).
USING(newData)
_, _, err := stmt.Sql() _, _, err := stmt.Sql()
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "argument mismatch: expected struct, got []int")
}()
table2.INSERT(table2ColInt).USING([]int{})
}
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1). stmt := table1.INSERT(table1Col1).

View file

@ -6,7 +6,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type lockMode string type TableLockMode string
const ( const (
LOCK_ACCESS_SHARE = "ACCESS SHARE" LOCK_ACCESS_SHARE = "ACCESS SHARE"
@ -22,13 +22,13 @@ const (
type LockStatement interface { type LockStatement interface {
Statement Statement
IN(lockMode lockMode) LockStatement IN(lockMode TableLockMode) LockStatement
NOWAIT() LockStatement NOWAIT() LockStatement
} }
type lockStatementImpl struct { type lockStatementImpl struct {
tables []WritableTable tables []WritableTable
lockMode lockMode lockMode TableLockMode
nowait bool nowait bool
} }
@ -38,7 +38,7 @@ func LOCK(tables ...WritableTable) LockStatement {
} }
} }
func (l *lockStatementImpl) IN(lockMode lockMode) LockStatement { func (l *lockStatementImpl) IN(lockMode TableLockMode) LockStatement {
l.lockMode = lockMode l.lockMode = lockMode
return l return l
} }

View file

@ -1,28 +1,32 @@
package sqlbuilder package sqlbuilder
import ( import (
"gotest.tools/assert"
"testing" "testing"
) )
func TestLockSingleTable(t *testing.T) { func TestLockTable(t *testing.T) {
lock := table1.LOCK().IN(LOCK_ROW_SHARE) assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_SHARE), `
LOCK TABLE db.table1 IN ACCESS SHARE MODE;
queryStr, _, err := lock.Sql() `)
assertStatement(t, table1.LOCK().IN(LOCK_ROW_SHARE), `
assert.NilError(t, err)
assert.Equal(t, queryStr, `
LOCK TABLE db.table1 IN ROW SHARE MODE; LOCK TABLE db.table1 IN ROW SHARE MODE;
`) `)
} assertStatement(t, table1.LOCK().IN(LOCK_ROW_EXCLUSIVE), `
LOCK TABLE db.table1 IN ROW EXCLUSIVE MODE;
func TestLockMultipleTable(t *testing.T) { `)
lock := LOCK(table2, table1).IN(LOCK_ACCESS_EXCLUSIVE).NOWAIT() assertStatement(t, table1.LOCK().IN(LOCK_SHARE_UPDATE_EXCLUSIVE), `
LOCK TABLE db.table1 IN SHARE UPDATE EXCLUSIVE MODE;
queryStr, _, err := lock.Sql() `)
assertStatement(t, table1.LOCK().IN(LOCK_SHARE), `
assert.NilError(t, err) LOCK TABLE db.table1 IN SHARE MODE;
assert.Equal(t, queryStr, ` `)
LOCK TABLE db.table2, db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT; assertStatement(t, table1.LOCK().IN(LOCK_SHARE_ROW_EXCLUSIVE), `
LOCK TABLE db.table1 IN SHARE ROW EXCLUSIVE MODE;
`)
assertStatement(t, table1.LOCK().IN(LOCK_EXCLUSIVE), `
LOCK TABLE db.table1 IN EXCLUSIVE MODE;
`)
assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_EXCLUSIVE), `
LOCK TABLE db.table1 IN ACCESS EXCLUSIVE MODE;
`) `)
} }

10
sqlbuilder/row_type.go Normal file
View file

@ -0,0 +1,10 @@
package sqlbuilder
type rowsType interface {
clause
hasRows()
}
type isRowsType struct{}
func (i *isRowsType) hasRows() {}

View file

@ -6,6 +6,13 @@ import (
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
) )
var (
UPDATE = newLock("UPDATE")
NO_KEY_UPDATE = newLock("NO KEY UPDATE")
SHARE = newLock("SHARE")
KEY_SHARE = newLock("KEY SHARE")
)
type SelectStatement interface { type SelectStatement interface {
Statement Statement
Expression Expression
@ -21,17 +28,15 @@ type SelectStatement interface {
LIMIT(limit int64) SelectStatement LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement OFFSET(offset int64) SelectStatement
FOR_UPDATE() SelectStatement FOR(lock SelectLock) SelectStatement
AsTable(alias string) ExpressionTable AsTable(alias string) ExpressionTable
} }
func SELECT(projection ...projection) SelectStatement { func SELECT(projection1 projection, projections ...projection) SelectStatement {
return newSelectStatement(nil, projection) return newSelectStatement(nil, append([]projection{projection1}, projections...))
} }
// NOTE: SelectStatement purposely does not implement the tableImpl interface since
// mysql's subquery performance is horrible.
type selectStatementImpl struct { type selectStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
isRowsType isRowsType
@ -46,7 +51,7 @@ type selectStatementImpl struct {
limit, offset int64 limit, offset int64
forUpdate bool lockFor SelectLock
} }
func newSelectStatement(table ReadableTable, projections []projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []projection) SelectStatement {
@ -55,7 +60,6 @@ func newSelectStatement(table ReadableTable, projections []projection) SelectSta
projections: projections, projections: projections,
limit: -1, limit: -1,
offset: -1, offset: -1,
forUpdate: false,
distinct: false, distinct: false,
} }
@ -161,15 +165,19 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
out.insertPreparedArgument(s.offset) out.insertPreparedArgument(s.offset)
} }
if s.forUpdate { if s.lockFor != nil {
out.newLine() out.newLine()
out.writeString("FOR UPDATE") out.writeString("FOR")
err := s.lockFor.serialize(select_statement, out)
if err != nil {
return err
}
} }
return nil return nil
} }
// Return the properly escaped SQL Statement, against the specified database
func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) { func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := queryData{} queryData := queryData{}
@ -229,11 +237,57 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
return s return s
} }
func (s *selectStatementImpl) FOR_UPDATE() SelectStatement { func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement {
s.forUpdate = true s.lockFor = lock
return s return s
} }
type SelectLock interface {
clause
NOWAIT() SelectLock
SKIP_LOCKED() SelectLock
}
type selectLockImpl struct {
lockStrength string
noWait, skipLocked bool
}
func newLock(name string) func() SelectLock {
return func() SelectLock {
return newSelectLock(name)
}
}
func newSelectLock(lockStrength string) SelectLock {
return &selectLockImpl{lockStrength: lockStrength}
}
func (s *selectLockImpl) NOWAIT() SelectLock {
s.noWait = true
return s
}
func (s *selectLockImpl) SKIP_LOCKED() SelectLock {
s.skipLocked = true
return s
}
func (s *selectLockImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
out.writeString(s.lockStrength)
if s.noWait {
out.writeString("NOWAIT")
}
if s.skipLocked {
out.writeString("SKIP LOCKED")
}
return nil
}
func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) error { func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(s, db, destination) return Query(s, db, destination)
} }

View file

@ -0,0 +1,108 @@
package sqlbuilder
import "testing"
func TestInvalidSelect(t *testing.T) {
assertStatementErr(t, SELECT(nil), "projection is nil")
}
func TestSelectDistinct(t *testing.T) {
assertStatement(t, SELECT(table1ColBool).DISTINCT(), `
SELECT DISTINCT table1.colBool AS "table1.colBool";
`)
}
func TestSelectFrom(t *testing.T) {
assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), `
SELECT table1.colInt AS "table1.colInt",
table2.colFloat AS "table2.colFloat"
FROM db.table1;
`)
assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), `
SELECT table1.colInt AS "table1.colInt",
table2.colFloat AS "table2.colFloat"
FROM db.table1
INNER JOIN db.table2 ON (table1.colInt = table2.colInt);
`)
assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), `
SELECT table1.colInt AS "table1.colInt",
table2.colFloat AS "table2.colFloat"
FROM db.table1
INNER JOIN db.table2 ON (table1.colInt = table2.colInt);
`)
}
func TestSelectWhere(t *testing.T) {
assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.colInt AS "table1.colInt"
FROM db.table1
WHERE $1;
`, true)
assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.colInt AS "table1.colInt"
FROM db.table1
WHERE table1.colInt >= $1;
`, int64(10))
}
func TestSelectGroupBy(t *testing.T) {
assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), `
SELECT table2.colInt AS "table2.colInt"
FROM db.table2
GROUP BY table2.colFloat;
`)
}
func TestSelectHaving(t *testing.T) {
assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.colInt AS "table3.colInt"
FROM db.table3
HAVING table1.colBool = $1;
`, true)
}
func TestSelectOrderBy(t *testing.T) {
assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), `
SELECT table2.colFloat AS "table2.colFloat"
FROM db.table2
ORDER BY table2.colInt DESC;
`)
}
func TestSelectLimitOffset(t *testing.T) {
assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), `
SELECT table2.colInt AS "table2.colInt"
FROM db.table2
LIMIT $1;
`, int64(10))
assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), `
SELECT table2.colInt AS "table2.colInt"
FROM db.table2
LIMIT $1
OFFSET $2;
`, int64(10), int64(2))
}
func TestSelectLock(t *testing.T) {
assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), `
SELECT table1.colBool AS "table1.colBool"
FROM db.table1
FOR UPDATE;
`)
assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), `
SELECT table1.colBool AS "table1.colBool"
FROM db.table1
FOR SHARE NOWAIT;
`)
assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(KEY_SHARE().NOWAIT()), `
SELECT table1.colBool AS "table1.colBool"
FROM db.table1
FOR KEY SHARE NOWAIT;
`)
assertStatement(t, SELECT(table1ColBool).FROM(table1).FOR(NO_KEY_UPDATE().SKIP_LOCKED()), `
SELECT table1.colBool AS "table1.colBool"
FROM db.table1
FOR NO KEY UPDATE SKIP LOCKED;
`)
}

View file

@ -8,7 +8,7 @@ import (
type readableTable interface { type readableTable interface {
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
SELECT(projections ...projection) SelectStatement SELECT(projection projection, projections ...projection) SelectStatement
// Creates a inner join tableName Expression using onCondition. // Creates a inner join tableName Expression using onCondition.
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
@ -59,8 +59,8 @@ type readableTableInterfaceImpl struct {
} }
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projections ...projection) SelectStatement { func (r *readableTableInterfaceImpl) SELECT(projection1 projection, projections ...projection) SelectStatement {
return newSelectStatement(r.parent, projections) return newSelectStatement(r.parent, append([]projection{projection1}, projections...))
} }
// Creates a inner join tableName Expression using onCondition. // Creates a inner join tableName Expression using onCondition.

View file

@ -65,6 +65,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
out := queryData{} out := queryData{}
err := clause.serialize(select_statement, &out) err := clause.serialize(select_statement, &out)
//fmt.Println(out.buff.String())
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
assert.Equal(t, err.Error(), errString) assert.Equal(t, err.Error(), errString)
} }
@ -82,6 +83,8 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string
func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
//fmt.Println(queryStr)
assert.Equal(t, queryStr, expectedQuery) assert.Equal(t, queryStr, expectedQuery)
assert.DeepEqual(t, args, expectedArgs) assert.DeepEqual(t, args, expectedArgs)
} }

View file

@ -1,35 +0,0 @@
package sqlbuilder
type rowsType interface {
clause
hasRows()
}
type isRowsType struct{}
func (i *isRowsType) hasRows() {}
// A clause that can be used in orderBy by
// A clause that is selectable.
//type projection interface {
// clause
// isProjectionInterface
//
// SerializeSqlForColumnList(out *bytes.Buffer) error
//}
//
// Boiler plates ...
//
//
//type isProjectionInterface interface {
// isProjectionType()
//}
//
//type isProjection struct {
//}
//
//func (p *isProjection) isProjectionType() {
//}

View file

@ -88,7 +88,7 @@ func serializeProjectionList(statement statementType, projections []projection,
} }
if col == nil { if col == nil {
return errors.New("projection Expression is nil") return errors.New("projection is nil")
} }
if err := col.serializeForProjection(statement, out); err != nil { if err := col.serializeForProjection(statement, out); err != nil {
@ -115,6 +115,16 @@ func serializeColumnNames(columns []column, out *queryData) error {
return nil return nil
} }
func columnListToProjectionList(columns []Column) []projection {
var ret []projection
for _, column := range columns {
ret = append(ret, column)
}
return ret
}
func isNil(v interface{}) bool { func isNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
} }
@ -132,9 +142,7 @@ func unwindRowFromModel(columns []column, data interface{}) []clause {
row := []clause{} row := []clause{}
if structValue.Kind() != reflect.Struct { mustBe(structValue, reflect.Struct)
return row
}
for _, column := range columns { for _, column := range columns {
columnName := column.Name() columnName := column.Name()
@ -143,7 +151,7 @@ func unwindRowFromModel(columns []column, data interface{}) []clause {
structField := structValue.FieldByName(structFieldName) structField := structValue.FieldByName(structFieldName)
if !structField.IsValid() { if !structField.IsValid() {
continue panic("missing struct field for column : " + column.Name())
} }
var field interface{} var field interface{}
@ -172,14 +180,10 @@ func unwindRowFromValues(value interface{}, values []interface{}) []clause {
return row return row
} }
func columnListToProjectionList(columns []Column) []projection { func mustBe(v reflect.Value, expected reflect.Kind) {
var ret []projection if k := v.Kind(); k != expected {
panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String())
for _, column := range columns {
ret = append(ret, column)
} }
return ret
} }
func Query(statement Statement, db execution.Db, destination interface{}) error { func Query(statement Statement, db execution.Db, destination interface{}) error {

View file

@ -1104,18 +1104,109 @@ LIMIT 20;
func TestLockTable(t *testing.T) { func TestLockTable(t *testing.T) {
expectedSql := ` expectedSql := `
LOCK TABLE dvds.address IN EXCLUSIVE MODE NOWAIT; LOCK TABLE dvds.address IN`
`
query := Address.LOCK().IN(LOCK_EXCLUSIVE).NOWAIT()
querySql, _, _ := query.Sql() var testData = []TableLockMode{
fmt.Println("-" + querySql + "-") LOCK_ACCESS_SHARE,
LOCK_ROW_SHARE,
LOCK_ROW_EXCLUSIVE,
LOCK_SHARE_UPDATE_EXCLUSIVE,
LOCK_SHARE,
LOCK_SHARE_ROW_EXCLUSIVE,
LOCK_EXCLUSIVE,
LOCK_ACCESS_EXCLUSIVE,
}
assertStatementSql(t, query, expectedSql) for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode)
assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE;\n")
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NilError(t, err) assert.NilError(t, err)
tx.Rollback()
}
for _, lockMode := range testData {
query := Address.LOCK().IN(lockMode).NOWAIT()
assertStatementSql(t, query, expectedSql+" "+string(lockMode)+" MODE NOWAIT;\n")
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
tx.Rollback()
}
}
func getRowLockTestData() map[SelectLock]string {
return map[SelectLock]string{
UPDATE(): "UPDATE",
NO_KEY_UPDATE(): "NO KEY UPDATE",
SHARE(): "SHARE",
KEY_SHARE(): "KEY SHARE",
}
}
func TestRowLock(t *testing.T) {
expectedSql := `
SELECT *
FROM dvds.address
LIMIT 3
FOR`
query := Address.
SELECT(STAR).
LIMIT(3)
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType)
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+";\n", int64(3))
tx, _ := db.Begin()
res, err := query.Exec(tx)
assert.NilError(t, err)
rowsAffected, _ := res.RowsAffected()
assert.Equal(t, rowsAffected, int64(3))
tx.Rollback()
}
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.NOWAIT())
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" NOWAIT;\n", int64(3))
tx, _ := db.Begin()
res, err := query.Exec(tx)
assert.NilError(t, err)
rowsAffected, _ := res.RowsAffected()
assert.Equal(t, rowsAffected, int64(3))
tx.Rollback()
}
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.SKIP_LOCKED())
assertStatementSql(t, query, expectedSql+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3))
tx, _ := db.Begin()
res, err := query.Exec(tx)
assert.NilError(t, err)
rowsAffected, _ := res.RowsAffected()
assert.Equal(t, rowsAffected, int64(3))
tx.Rollback()
}
} }

View file

@ -213,6 +213,11 @@ WHERE link.id = 201;
} }
func TestUpdateWithInvalidModelData(t *testing.T) { func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : id")
}()
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)