diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go new file mode 100644 index 0000000..8aae3ac --- /dev/null +++ b/sqlbuilder/column.go @@ -0,0 +1,302 @@ +// Modeling of columns + +package sqlbuilder + +import ( + "bytes" + "regexp" + + "github.com/dropbox/godropbox/errors" +) + +// XXX: Maybe add UIntColumn + +// Representation of a table for query generation +type Column interface { + isProjectionInterface + + Name() string + // Serialization for use in column lists + SerializeSqlForColumnList(out *bytes.Buffer) error + // Serialization for use in an expression (Clause) + SerializeSql(out *bytes.Buffer) error + + // Internal function for tracking table that a column belongs to + // for the purpose of serialization + setTableName(table string) error +} + +type NullableColumn bool + +const ( + Nullable NullableColumn = true + NotNullable NullableColumn = false +) + +// A column that can be refer to outside of the projection list +type NonAliasColumn interface { + Column + isOrderByClauseInterface + isExpressionInterface +} + +type Collation string + +const ( + UTF8CaseInsensitive Collation = "utf8_unicode_ci" + UTF8CaseSensitive Collation = "utf8_unicode" + UTF8Binary Collation = "utf8_bin" +) + +// Representation of MySQL charsets +type Charset string + +const ( + UTF8 Charset = "utf8" +) + +// The base type for real materialized columns. +type baseColumn struct { + isProjection + isExpression + name string + nullable NullableColumn + table string +} + +func (c *baseColumn) Name() string { + return c.name +} + +func (c *baseColumn) setTableName(table string) error { + c.table = table + return nil +} + +func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + if c.table != "" { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.table) + _, _ = out.WriteString("`.") + } + _, _ = out.WriteString("`") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { + return c.SerializeSqlForColumnList(out) +} + +type bytesColumn struct { + baseColumn + isExpression +} + +// Representation of VARBINARY/BLOB columns +// This function will panic if name is not valid +func BytesColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in bytes column") + } + bc := &bytesColumn{} + bc.name = name + bc.nullable = nullable + return bc +} + +type stringColumn struct { + baseColumn + isExpression + charset Charset + collation Collation +} + +// Representation of VARCHAR/TEXT columns +// This function will panic if name is not valid +func StrColumn( + name string, + charset Charset, + collation Collation, + nullable NullableColumn) NonAliasColumn { + + if !validIdentifierName(name) { + panic("Invalid column name in str column") + } + sc := &stringColumn{charset: charset, collation: collation} + sc.name = name + sc.nullable = nullable + return sc +} + +type dateTimeColumn struct { + baseColumn + isExpression +} + +// Representation of DateTime columns +// This function will panic if name is not valid +func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in datetime column") + } + dc := &dateTimeColumn{} + dc.name = name + dc.nullable = nullable + return dc +} + +type integerColumn struct { + baseColumn + isExpression +} + +// Representation of any integer column +// This function will panic if name is not valid +func IntColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in int column") + } + ic := &integerColumn{} + ic.name = name + ic.nullable = nullable + return ic +} + +type doubleColumn struct { + baseColumn + isExpression +} + +// Representation of any double column +// This function will panic if name is not valid +func DoubleColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in int column") + } + ic := &doubleColumn{} + ic.name = name + ic.nullable = nullable + return ic +} + +type booleanColumn struct { + baseColumn + isExpression + + // XXX: Maybe allow isBoolExpression (for now, not included because + // the deferred lookup equivalent can never be isBoolExpression) +} + +// Representation of TINYINT used as a bool +// This function will panic if name is not valid +func BoolColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in bool column") + } + bc := &booleanColumn{} + bc.name = name + bc.nullable = nullable + return bc +} + +type aliasColumn struct { + baseColumn + expression Expression +} + +func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + if !validIdentifierName(c.name) { + return errors.Newf( + "Invalid alias name `%s`. Generated sql: %s", + c.name, + out.String()) + } + if c.expression == nil { + return errors.Newf( + "Cannot alias a nil expression. Generated sql: %s", + out.String()) + } + + _ = out.WriteByte('(') + if c.expression == nil { + return errors.Newf("nil alias clause. Generate sql: %s", out.String()) + } + if err := c.expression.SerializeSql(out); err != nil { + return err + } + _, _ = out.WriteString(") AS `") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *aliasColumn) setTableName(table string) error { + return errors.Newf( + "Alias column '%s' should never have setTableName called on it", + c.name) +} + +// Representation of aliased clauses (expression AS name) +func Alias(name string, c Expression) Column { + ac := &aliasColumn{} + ac.name = name + ac.expression = c + return ac +} + +// This is a strict subset of the actual allowed identifiers +var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$") + +// Returns true if the given string is suitable as an identifier. +func validIdentifierName(name string) bool { + return validIdentifierRegexp.MatchString(name) +} + +// Pseudo Column type returned by table.C(name) +type deferredLookupColumn struct { + isProjection + isExpression + table *Table + colName string + + cachedColumn NonAliasColumn +} + +func (c *deferredLookupColumn) Name() string { + return c.colName +} + +func (c *deferredLookupColumn) SerializeSqlForColumnList( + out *bytes.Buffer) error { + + return c.SerializeSql(out) +} + +func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { + if c.cachedColumn != nil { + return c.cachedColumn.SerializeSql(out) + } + + col, err := c.table.getColumn(c.colName) + if err != nil { + return err + } + + c.cachedColumn = col + return col.SerializeSql(out) +} + +func (c *deferredLookupColumn) setTableName(table string) error { + return errors.Newf( + "Lookup column '%s' should never have setTableName called on it", + c.colName) +} diff --git a/sqlbuilder/column_test.go b/sqlbuilder/column_test.go new file mode 100644 index 0000000..695384c --- /dev/null +++ b/sqlbuilder/column_test.go @@ -0,0 +1,208 @@ +package sqlbuilder + +import ( + "bytes" + "testing" + + gc "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + gc.TestingT(t) +} + +type ColumnSuite struct { +} + +var _ = gc.Suite(&ColumnSuite{}) + +// +// tests for baseColumn and columns that extends baseColumn +// + +func (s *ColumnSuite) TestRealColumnName(c *gc.C) { + col := IntColumn("col", Nullable) + + c.Assert(col.Name(), gc.Equals, "col") +} + +func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) { + col := IntColumn("col", Nullable) + + // Without table name + buf := &bytes.Buffer{} + + err := col.SerializeSqlForColumnList(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`col`") + + // With table name + err = col.setTableName("foo") + c.Assert(err, gc.IsNil) + + buf = &bytes.Buffer{} + + err = col.SerializeSqlForColumnList(buf) + c.Assert(err, gc.IsNil) + + sql = buf.String() + c.Assert(sql, gc.Equals, "`foo`.`col`") +} + +func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) { + col := IntColumn("col", Nullable) + + // Without table name + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`col`") + + // With table name + err = col.setTableName("foo") + c.Assert(err, gc.IsNil) + + buf = &bytes.Buffer{} + + err = col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql = buf.String() + c.Assert(sql, gc.Equals, "`foo`.`col`") +} + +// +// tests for AliasCoulmns +// + +func (s *ColumnSuite) TestAliasColumnName(c *gc.C) { + col := Alias("foo", SqlFunc("max", table1Col1)) + + c.Assert(col.Name(), gc.Equals, "foo") +} + +func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnList(c *gc.C) { + col := Alias("foo", SqlFunc("max", table1Col1)) + + buf := &bytes.Buffer{} + err := col.SerializeSqlForColumnList(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(err, gc.IsNil) + + c.Assert(sql, gc.Equals, "(max(`table1`.`col1`)) AS `foo`") +} + +func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnListNilExpr(c *gc.C) { + col := Alias("foo", nil) + + buf := &bytes.Buffer{} + err := col.SerializeSqlForColumnList(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ColumnSuite) TestAliasColumnSerializeSqlForColumnListInvalidAlias( + c *gc.C) { + + col := Alias("1234", SqlFunc("max", table1Col1)) + + buf := &bytes.Buffer{} + err := col.SerializeSqlForColumnList(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ColumnSuite) TestAliasColumnSerializeSql(c *gc.C) { + col := Alias("foo", SqlFunc("max", table1Col1)) + + buf := &bytes.Buffer{} + err := col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`foo`") +} + +func (s *ColumnSuite) TestAliasColumnSetTableName(c *gc.C) { + col := Alias("foo", SqlFunc("max", table1Col1)) + + // should always error + err := col.setTableName("test") + c.Assert(err, gc.NotNil) +} + +// +// tests for deferredLookkupColumnName +// + +func (s *ColumnSuite) TestDeferredLookupColumnName(c *gc.C) { + col := table1.C("foo") + + c.Assert(col.Name(), gc.Equals, "foo") +} + +func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlForColumnList( + c *gc.C) { + + col := table1.C("col1") + + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`") + + // check cached lookup + buf = &bytes.Buffer{} + + err = col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql = buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`") +} + +func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlForColumnListInvalidName( + c *gc.C) { + col := table1.C("foo") + + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ColumnSuite) TestDeferredLookupColumnSerializeSql(c *gc.C) { + col := table1.C("col1") + + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`") +} + +func (s *ColumnSuite) TestDeferredLookupColumnSerializeSqlInvalidName(c *gc.C) { + col := table1.C("foo") + + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ColumnSuite) TestDeferredLookupColumnSetTableName(c *gc.C) { + col := table1.C("col1") + + err := col.setTableName("foo") + c.Assert(err, gc.NotNil) +} diff --git a/sqlbuilder/doc.go b/sqlbuilder/doc.go new file mode 100644 index 0000000..3f9170a --- /dev/null +++ b/sqlbuilder/doc.go @@ -0,0 +1,25 @@ +// A library for generating sql programmatically. +// +// SQL COMPATIBILITY NOTE: sqlbuilder is designed to generate valid MySQL sql +// statements. The generated statements may not work for other sql variants. +// For instances, the generated statements does not currently work for +// PostgreSQL since column identifiers are escaped with backquotes. +// Patches to support other sql flavors are welcome! (see +// https://godropbox/issues/33 for additional details). +// +// Known limitations for SELECT queries: +// - does not support subqueries (since mysql is bad at it) +// - does not currently support join table alias (and hence self join) +// - does not support NATURAL joins and join USING +// +// Known limitation for INSERT statements: +// - does not support "INSERT INTO SELECT" +// +// Known limitation for UPDATE statements: +// - does not support update without a WHERE clause (since it is dangerous) +// - does not support multi-table update +// +// Known limitation for DELETE statements: +// - does not support delete without a WHERE clause (since it is dangerous) +// - does not support multi-table delete +package sqlbuilder diff --git a/sqlbuilder/example_test.go b/sqlbuilder/example_test.go new file mode 100644 index 0000000..b5eb406 --- /dev/null +++ b/sqlbuilder/example_test.go @@ -0,0 +1,38 @@ +package sqlbuilder + +import "fmt" + +func Example() { + t1 := NewTable( + "parent_prefix", + IntColumn("ns_id", NotNullable), + IntColumn("hash", NotNullable), + StrColumn("prefix", + UTF8, + UTF8CaseInsensitive, + NotNullable)) + + t2 := NewTable( + "sfj", + IntColumn("ns_id", NotNullable), + IntColumn("sjid", NotNullable), + StrColumn("filename", + UTF8, + UTF8CaseInsensitive, + NotNullable)) + + ns_id1 := t1.C("ns_id") + prefix := t1.C("prefix") + ns_id2 := t2.C("ns_id") + sjid := t2.C("sjid") + filename := t2.C("filename") + + in := []int32{1, 2, 3} + join := t2.LeftJoinOn(t1, Eq(ns_id1, ns_id2)) + q := join.Select(ns_id2, sjid, prefix, filename).Where( + And(EqL(ns_id2, 456), In(sjid, in))) + text, _ := q.String("shard1") + fmt.Println(text) + // Output: + // SELECT `sfj`.`ns_id`,`sfj`.`sjid`,`parent_prefix`.`prefix`,`sfj`.`filename` FROM `shard1`.`sfj` LEFT JOIN `shard1`.`parent_prefix` ON `parent_prefix`.`ns_id`=`sfj`.`ns_id` WHERE (`sfj`.`ns_id`=456 AND `sfj`.`sjid` IN (1,2,3)) +} diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go new file mode 100644 index 0000000..f7ef1f4 --- /dev/null +++ b/sqlbuilder/expression.go @@ -0,0 +1,732 @@ +// Query building functions for expression components +package sqlbuilder + +import ( + "bytes" + "reflect" + "strconv" + "strings" + "time" + + "github.com/dropbox/godropbox/database/sqltypes" + "github.com/dropbox/godropbox/errors" +) + +type orderByClause struct { + isOrderByClause + expression Expression + ascent bool +} + +func (o *orderByClause) SerializeSql(out *bytes.Buffer) error { + if o.expression == nil { + return errors.Newf( + "nil order by clause. Generated sql: %s", + out.String()) + } + + if err := o.expression.SerializeSql(out); err != nil { + return err + } + + if o.ascent { + _, _ = out.WriteString(" ASC") + } else { + _, _ = out.WriteString(" DESC") + } + + return nil +} + +func Asc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: true} +} + +func Desc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: false} +} + +// Representation of an escaped literal +type literalExpression struct { + isExpression + value sqltypes.Value +} + +func (c literalExpression) SerializeSql(out *bytes.Buffer) error { + sqltypes.Value(c.value).EncodeSql(out) + return nil +} + +func serializeClauses( + clauses []Clause, + separator []byte, + out *bytes.Buffer) (err error) { + + if clauses == nil || len(clauses) == 0 { + return errors.Newf("Empty clauses. Generated sql: %s", out.String()) + } + + if clauses[0] == nil { + return errors.Newf("nil clause. Generated sql: %s", out.String()) + } + if err = clauses[0].SerializeSql(out); err != nil { + return + } + + for _, c := range clauses[1:] { + _, _ = out.Write(separator) + + if c == nil { + return errors.Newf("nil clause. Generated sql: %s", out.String()) + } + if err = c.SerializeSql(out); err != nil { + return + } + } + + return nil +} + +// Representation of n-ary conjunctions (AND/OR) +type conjunctExpression struct { + isExpression + isBoolExpression + expressions []BoolExpression + conjunction []byte +} + +func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(conj.expressions) == 0 { + return errors.Newf( + "Empty conjunction. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) + for i, expr := range conj.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, conj.conjunction, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +// Representation of n-ary arithmetic (+ - * /) +type arithmeticExpression struct { + isExpression + expressions []Expression + operator []byte +} + +func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(arith.expressions) == 0 { + return errors.Newf( + "Empty arithmetic expression. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(arith.expressions), len(arith.expressions)) + for i, expr := range arith.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, arith.operator, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +type tupleExpression struct { + isExpression + elements listClause +} + +func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer) error { + if len(tuple.elements.clauses) < 1 { + return errors.Newf("Tuples must include at least one element") + } + return tuple.elements.SerializeSql(out) +} + +func Tuple(exprs ...Expression) Expression { + clauses := make([]Clause, 0, len(exprs)) + for _, expr := range exprs { + clauses = append(clauses, expr) + } + return &tupleExpression{ + elements: listClause{ + clauses: clauses, + includeParentheses: true, + }, + } +} + +// Representation of a tuple enclosed, comma separated list of clauses +type listClause struct { + clauses []Clause + includeParentheses bool +} + +func (list *listClause) SerializeSql(out *bytes.Buffer) error { + if list.includeParentheses { + _ = out.WriteByte('(') + } + + if err := serializeClauses(list.clauses, []byte(","), out); err != nil { + return err + } + + if list.includeParentheses { + _ = out.WriteByte(')') + } + return nil +} + +// A not expression which negates a expression value +type negateExpression struct { + isExpression + isBoolExpression + + nested BoolExpression +} + +func (c *negateExpression) SerializeSql(out *bytes.Buffer) (err error) { + _, _ = out.WriteString("NOT (") + + if c.nested == nil { + return errors.Newf("nil nested. Generated sql: %s", out.String()) + } + if err = c.nested.SerializeSql(out); err != nil { + return + } + + _ = out.WriteByte(')') + return nil +} + +// Returns a representation of "not expr" +func Not(expr BoolExpression) BoolExpression { + return &negateExpression{ + nested: expr, + } +} + +// Representation of binary operations (e.g. comparisons, arithmetic) +type binaryExpression struct { + isExpression + lhs, rhs Expression + operator []byte +} + +func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) { + if c.lhs == nil { + return errors.Newf("nil lhs. Generated sql: %s", out.String()) + } + if err = c.lhs.SerializeSql(out); err != nil { + return + } + + _, _ = out.Write(c.operator) + + if c.rhs == nil { + return errors.Newf("nil rhs. Generated sql: %s", out.String()) + } + if err = c.rhs.SerializeSql(out); err != nil { + return + } + + return nil +} + +// A binary expression that evaluates to a boolean value. +type boolExpression struct { + isBoolExpression + binaryExpression +} + +func newBoolExpression(lhs, rhs Expression, operator []byte) *boolExpression { + // go does not allow {} syntax for initializing promoted fields ... + expr := new(boolExpression) + expr.lhs = lhs + expr.rhs = rhs + expr.operator = operator + return expr +} + +type funcExpression struct { + isExpression + funcName string + args *listClause +} + +func (c *funcExpression) SerializeSql(out *bytes.Buffer) (err error) { + if !validIdentifierName(c.funcName) { + return errors.Newf( + "Invalid function name: %s. Generated sql: %s", + c.funcName, + out.String()) + } + _, _ = out.WriteString(c.funcName) + if c.args == nil { + _, _ = out.WriteString("()") + } else { + return c.args.SerializeSql(out) + } + return nil +} + +// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) +func SqlFunc(funcName string, expressions ...Expression) Expression { + f := &funcExpression{ + funcName: funcName, + } + if len(expressions) > 0 { + args := make([]Clause, len(expressions), len(expressions)) + for i, expr := range expressions { + args[i] = expr + } + + f.args = &listClause{ + clauses: args, + includeParentheses: true, + } + } + return f +} + +type intervalExpression struct { + isExpression + duration time.Duration + negative bool +} + +var intervalSep = ":" + +func (c *intervalExpression) SerializeSql(out *bytes.Buffer) (err error) { + hours := c.duration / time.Hour + minutes := (c.duration % time.Hour) / time.Minute + sec := (c.duration % time.Minute) / time.Second + msec := (c.duration % time.Second) / time.Microsecond + _, _ = out.WriteString("INTERVAL '") + if c.negative { + _, _ = out.WriteString("-") + } + _, _ = out.WriteString(strconv.FormatInt(int64(hours), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(sec), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(msec), 10)) + _, _ = out.WriteString("' HOUR_MICROSECOND") + return nil +} + +// Interval returns a representation of duration +// in a form "INTERVAL `hour:min:sec:microsec` HOUR_MICROSECOND" +func Interval(duration time.Duration) Expression { + negative := false + if duration < 0 { + negative = true + duration = -duration + } + return &intervalExpression{ + duration: duration, + negative: negative, + } +} + +var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%") + +func EscapeForLike(s string) string { + return likeEscaper.Replace(s) +} + +// Returns an escaped literal string +func Literal(v interface{}) Expression { + value, err := sqltypes.BuildValue(v) + if err != nil { + panic(errors.Wrap(err, "Invalid literal value")) + } + return &literalExpression{value: value} +} + +// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses +func And(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" AND "), + } +} + +// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses +func Or(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" OR "), + } +} + +func Like(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" LIKE ")) +} + +func LikeL(lhs Expression, val string) BoolExpression { + return Like(lhs, Literal(val)) +} + +func Regexp(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" REGEXP ")) +} + +func RegexpL(lhs Expression, val string) BoolExpression { + return Regexp(lhs, Literal(val)) +} + +// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses +func Add(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" + "), + } +} + +// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses +func Sub(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" - "), + } +} + +// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses +func Mul(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" * "), + } +} + +// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses +func Div(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" / "), + } +} + +// Returns a representation of "a=b" +func Eq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS ")) + } + return newBoolExpression(lhs, rhs, []byte("=")) +} + +// Returns a representation of "a=b", where b is a literal +func EqL(lhs Expression, val interface{}) BoolExpression { + return Eq(lhs, Literal(val)) +} + +// Returns a representation of "a!=b" +func Neq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS NOT ")) + } + return newBoolExpression(lhs, rhs, []byte("!=")) +} + +// Returns a representation of "a!=b", where b is a literal +func NeqL(lhs Expression, val interface{}) BoolExpression { + return Neq(lhs, Literal(val)) +} + +// Returns a representation of "ab" +func Gt(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">")) +} + +// Returns a representation of "a>b", where b is a literal +func GtL(lhs Expression, val interface{}) BoolExpression { + return Gt(lhs, Literal(val)) +} + +// Returns a representation of "a>=b" +func Gte(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">=")) +} + +// Returns a representation of "a>=b", where b is a literal +func GteL(lhs Expression, val interface{}) BoolExpression { + return Gte(lhs, Literal(val)) +} + +func BitOr(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" | "), + } +} + +func BitAnd(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" & "), + } +} + +func BitXor(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" ^ "), + } +} + +func Plus(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" + "), + } +} + +func Minus(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" - "), + } +} + +// in expression representation +type inExpression struct { + isExpression + isBoolExpression + + lhs Expression + rhs *listClause + + err error +} + +func (c *inExpression) SerializeSql(out *bytes.Buffer) error { + if c.err != nil { + return errors.Wrap(c.err, "Invalid IN expression") + } + + if c.lhs == nil { + return errors.Newf( + "lhs of in expression is nil. Generated sql: %s", + out.String()) + } + + // We'll serialize the lhs even if we don't need it to ensure no error + buf := &bytes.Buffer{} + + err := c.lhs.SerializeSql(buf) + if err != nil { + return err + } + + if c.rhs == nil { + _, _ = out.WriteString("FALSE") + return nil + } + + _, _ = out.WriteString(buf.String()) + _, _ = out.WriteString(" IN ") + + err = c.rhs.SerializeSql(out) + if err != nil { + return err + } + + return nil +} + +// Returns a representation of "a IN (b[0], ..., b[n-1])", where b is a list +// of literals valList must be a slice type +func In(lhs Expression, valList interface{}) BoolExpression { + var clauses []Clause + switch val := valList.(type) { + // This atrocious body of copy-paste code is due to the fact that if you + // try to merge the cases, you can't treat val as a list + case []int: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []float64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []string: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case [][]byte: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []time.Time: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Numeric: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Fractional: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.String: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Value: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + default: + return &inExpression{ + err: errors.Newf( + "Unknown value list type in IN clause: %s", + reflect.TypeOf(valList)), + } + } + + expr := &inExpression{lhs: lhs} + if len(clauses) > 0 { + expr.rhs = &listClause{clauses: clauses, includeParentheses: true} + } + return expr +} + +type ifExpression struct { + isExpression + conditional BoolExpression + trueExpression Expression + falseExpression Expression +} + +func (exp *ifExpression) SerializeSql(out *bytes.Buffer) error { + _, _ = out.WriteString("IF(") + _ = exp.conditional.SerializeSql(out) + _, _ = out.WriteString(",") + _ = exp.trueExpression.SerializeSql(out) + _, _ = out.WriteString(",") + _ = exp.falseExpression.SerializeSql(out) + _, _ = out.WriteString(")") + return nil +} + +// Returns a representation of an if-expression, of the form: +// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE) +func If(conditional BoolExpression, + trueExpression Expression, + falseExpression Expression) Expression { + return &ifExpression{ + conditional: conditional, + trueExpression: trueExpression, + falseExpression: falseExpression, + } +} + +type columnValueExpression struct { + isExpression + column NonAliasColumn +} + +func ColumnValue(col NonAliasColumn) Expression { + return &columnValueExpression{ + column: col, + } +} + +func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error { + _, _ = out.WriteString("VALUES(") + _ = cv.column.SerializeSqlForColumnList(out) + _ = out.WriteByte(')') + return nil +} diff --git a/sqlbuilder/expression_test.go b/sqlbuilder/expression_test.go new file mode 100644 index 0000000..648b653 --- /dev/null +++ b/sqlbuilder/expression_test.go @@ -0,0 +1,547 @@ +package sqlbuilder + +import ( + "bytes" + "time" + + gc "gopkg.in/check.v1" +) + +type ExprSuite struct { +} + +var _ = gc.Suite(&ExprSuite{}) + +func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) { + expr := And() + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) { + expr := And(nil, EqL(table1Col1, 1)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) { + expr := And(EqL(table1Col1, 1)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`=1") +} + +func (s *ExprSuite) TestTupleExpr(c *gc.C) { + + expr := Tuple() + buf := &bytes.Buffer{} + err := expr.SerializeSql(buf) + c.Assert(err, gc.NotNil) + + expr = Tuple(table1Col1, Literal(1), Literal("five")) + err = expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "(`table1`.`col1`,1,'five')") + +} + +func (s *ExprSuite) TestLikeExpr(c *gc.C) { + expr := LikeL(table1Col1, EscapeForLike("%my_prefix")+"%") + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`table1`.`col1` LIKE '\\%my\\_prefix%'") + +} + +func (s *ExprSuite) TestRegexExpr(c *gc.C) { + expr := RegexpL(table1Col1, "[[:<:]]log|[[.low-line.]]log") + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`table1`.`col1` REGEXP '[[:<:]]log|[[.low-line.]]log'") + +} + +func (s *ExprSuite) TestAndExpr(c *gc.C) { + expr := And(EqL(table1Col1, 1), EqL(table1Col2, 2), EqL(table1Col3, 3)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "(`table1`.`col1`=1 AND `table1`.`col2`=2 AND `table1`.`col3`=3)") +} + +func (s *ExprSuite) TestOrExpr(c *gc.C) { + expr := Or(EqL(table1Col1, 1), EqL(table1Col2, 2), EqL(table1Col3, 3)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "(`table1`.`col1`=1 OR `table1`.`col2`=2 OR `table1`.`col3`=3)") +} + +func (s *ExprSuite) TestAddExpr(c *gc.C) { + expr := Add(Literal(1), Literal(2), Literal(3)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "(1 + 2 + 3)") +} + +func (s *ExprSuite) TestSubExpr(c *gc.C) { + expr := Sub(Literal(1), Literal(2), Literal(3)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "(1 - 2 - 3)") +} + +func (s *ExprSuite) TestMulExpr(c *gc.C) { + expr := Mul(Literal(1), Literal(2), Literal(3)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "(1 * 2 * 3)") +} + +func (s *ExprSuite) TestDivExpr(c *gc.C) { + expr := Div(Literal(1), Literal(2), Literal(3)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "(1 / 2 / 3)") +} + +func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) { + expr := Gt(nil, table1Col1) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ExprSuite) TestNegateExpr(c *gc.C) { + expr := Not(EqL(table1Col1, 123)) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "NOT (`table1`.`col1`=123)") +} + +func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) { + expr := Lt(table1Col1, nil) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ExprSuite) TestEqExpr(c *gc.C) { + expr := EqL(table1Col1, 321) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`=321") +} + +func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) { + expr := EqL(table1Col1, nil) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1` IS null") +} + +func (s *ExprSuite) TestNeqExpr(c *gc.C) { + expr := NeqL(table1Col1, 123) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`!=123") +} + +func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) { + expr := NeqL(table1Col1, nil) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1` IS NOT null") +} + +func (s *ExprSuite) TestLtExpr(c *gc.C) { + expr := LtL(table1Col1, -1.5) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`<-1.5") +} + +func (s *ExprSuite) TestLteExpr(c *gc.C) { + expr := LteL(table1Col1, "foo\"';drop user table;") + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`table1`.`col1`<='foo\\\"\\';drop user table;'") +} + +func (s *ExprSuite) TestGtExpr(c *gc.C) { + expr := GtL(table1Col1, 1.1) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`>1.1") +} + +func (s *ExprSuite) TestGteExpr(c *gc.C) { + expr := GteL(table1Col1, 1) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`>=1") +} + +func (s *ExprSuite) TestInExpr(c *gc.C) { + values := []int32{1, 2, 3} + expr := In(table1Col1, values) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1` IN (1,2,3)") +} + +func (s *ExprSuite) TestInExprEmptyList(c *gc.C) { + values := []int32{} + expr := In(table1Col1, values) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "FALSE") +} + +func (s *ExprSuite) TestSqlFuncExprNilInArgList(c *gc.C) { + expr := SqlFunc("rand", nil) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ExprSuite) TestSqlFuncExprEmptyArgList(c *gc.C) { + expr := SqlFunc("rand") + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "rand()") +} + +func (s *ExprSuite) TestSqlFuncExprNonEmptyArgList(c *gc.C) { + expr := SqlFunc("add", table1Col1, table1Col2) + + buf := &bytes.Buffer{} + + err := expr.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "add(`table1`.`col1`,`table1`.`col2`)") +} + +func (s *ExprSuite) TestOrderByClauseNilExpr(c *gc.C) { + clause := Asc(nil) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *ExprSuite) TestAsc(c *gc.C) { + clause := Asc(table1Col1) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1` ASC") +} + +func (s *ExprSuite) TestDesc(c *gc.C) { + clause := Desc(table1Col1) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1` DESC") +} + +func (s *ExprSuite) TestIf(c *gc.C) { + test := GtL(table1Col1, 1.1) + clause := If(test, table1Col1, table1Col2) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "IF(`table1`.`col1`>1.1,`table1`.`col1`,`table1`.`col2`)") +} + +func (s *ExprSuite) TestColumnValue(c *gc.C) { + clause := ColumnValue(table1Col1) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "VALUES(`table1`.`col1`)") +} + +func (s *ExprSuite) TestBitwiseOr(c *gc.C) { + clause := BitOr(Literal(1), Literal(2)) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "1 | 2") +} + +func (s *ExprSuite) TestBitwiseAnd(c *gc.C) { + clause := BitAnd(Literal(1), Literal(2)) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "1 & 2") +} + +func (s *ExprSuite) TestBitwiseXor(c *gc.C) { + clause := BitXor(Literal(1), Literal(2)) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "1 ^ 2") +} + +func (s *ExprSuite) TestPlus(c *gc.C) { + clause := Plus(Literal(1), Literal(2)) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "1 + 2") +} + +func (s *ExprSuite) TestMinus(c *gc.C) { + clause := Minus(Literal(1), Literal(2)) + + buf := &bytes.Buffer{} + + err := clause.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "1 - 2") +} + +func (s *ExprSuite) TestInterval(c *gc.C) { + testTable := []struct { + interval time.Duration + expected string + expectedErr error + }{ + { + interval: 50 * time.Microsecond, + expected: "INTERVAL '0:0:0:50' HOUR_MICROSECOND", + }, + { + interval: -50 * time.Microsecond, + expected: "INTERVAL '-0:0:0:50' HOUR_MICROSECOND", + }, + { + interval: 50*time.Microsecond + 50*time.Second, + expected: "INTERVAL '0:0:50:50' HOUR_MICROSECOND", + }, + { + interval: 50*time.Microsecond + + 50*time.Second + + 50*time.Minute, + expected: "INTERVAL '0:50:50:50' HOUR_MICROSECOND", + }, + { + interval: 50*time.Microsecond + + 50*time.Second + + 50*time.Minute + + 50*time.Hour, + expected: "INTERVAL '50:50:50:50' HOUR_MICROSECOND", + }, + { + interval: 50 * time.Hour, + expected: "INTERVAL '50:0:0:0' HOUR_MICROSECOND", + }, + { + interval: 50*time.Hour + 50*time.Minute, + expected: "INTERVAL '50:50:0:0' HOUR_MICROSECOND", + }, + { + interval: 50*time.Hour + 50*time.Minute + 50*time.Second, + expected: "INTERVAL '50:50:50:0' HOUR_MICROSECOND", + }, + { + interval: 0, + expected: "INTERVAL '0:0:0:0' HOUR_MICROSECOND", + }, + { + interval: 50 * time.Nanosecond, + expected: "INTERVAL '0:0:0:0' HOUR_MICROSECOND", + }, + } + buf := &bytes.Buffer{} + + for i, tt := range testTable { + buf.Reset() + err := Interval(tt.interval).SerializeSql(buf) + c.Assert(err, gc.Equals, tt.expectedErr, + gc.Commentf("experiment #%d", i)) + if err == nil { + c.Assert(buf.String(), gc.Equals, tt.expected, + gc.Commentf("experiment #%d", i)) + } + } +} diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go new file mode 100644 index 0000000..3536c77 --- /dev/null +++ b/sqlbuilder/statement.go @@ -0,0 +1,1019 @@ +package sqlbuilder + +import ( + "bytes" + "fmt" + "regexp" + + "github.com/dropbox/godropbox/errors" +) + +type Statement interface { + // String returns generated SQL as string. + String(database string) (sql string, err error) +} + +type SelectStatement interface { + Statement + + Where(expression BoolExpression) SelectStatement + AndWhere(expression BoolExpression) SelectStatement + GroupBy(expressions ...Expression) SelectStatement + OrderBy(clauses ...OrderByClause) SelectStatement + Limit(limit int64) SelectStatement + Distinct() SelectStatement + WithSharedLock() SelectStatement + ForUpdate() SelectStatement + Offset(offset int64) SelectStatement + Comment(comment string) SelectStatement + Copy() SelectStatement +} + +type InsertStatement interface { + Statement + + // Add a row of values to the insert statement. + Add(row ...Expression) InsertStatement + AddOnDuplicateKeyUpdate(col NonAliasColumn, expr Expression) InsertStatement + Comment(comment string) InsertStatement + IgnoreDuplicates(ignore bool) InsertStatement +} + +// By default, rows selected by a UNION statement are out-of-order +// If you have an ORDER BY on an inner SELECT statement, the only thing +// it affects is the LIMIT clause on that inner statement (the ordering will +// still be out-of-order). +type UnionStatement interface { + Statement + + // Warning! You cannot include table names for the next 4 clauses, or + // you'll get errors like: + // Table 'server_file_journal' from one of the SELECTs cannot be used in + // global ORDER clause + Where(expression BoolExpression) UnionStatement + AndWhere(expression BoolExpression) UnionStatement + GroupBy(expressions ...Expression) UnionStatement + OrderBy(clauses ...OrderByClause) UnionStatement + + Limit(limit int64) UnionStatement + Offset(offset int64) UnionStatement +} + +type UpdateStatement interface { + Statement + + Set(column NonAliasColumn, expression Expression) UpdateStatement + Where(expression BoolExpression) UpdateStatement + OrderBy(clauses ...OrderByClause) UpdateStatement + Limit(limit int64) UpdateStatement + Comment(comment string) UpdateStatement +} + +type DeleteStatement interface { + Statement + + Where(expression BoolExpression) DeleteStatement + OrderBy(clauses ...OrderByClause) DeleteStatement + Limit(limit int64) DeleteStatement + Comment(comment string) DeleteStatement +} + +// LockStatement is used to take Read/Write lock on tables. +// See http://dev.mysql.com/doc/refman/5.0/en/lock-tables.html +type LockStatement interface { + Statement + + AddReadLock(table *Table) LockStatement + AddWriteLock(table *Table) LockStatement +} + +// UnlockStatement can be used to release table locks taken using LockStatement. +// NOTE: You can not selectively release a lock and continue to hold lock on +// another table. UnlockStatement releases all the lock held in the current +// session. +type UnlockStatement interface { + Statement +} + +// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID. +type GtidNextStatement interface { + Statement +} + +// +// UNION SELECT Statement ====================================================== +// + +func Union(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: true, + } +} + +func UnionAll(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: false, + } +} + +// Similar to selectStatementImpl, but less complete +type unionStatementImpl struct { + selects []SelectStatement + where BoolExpression + group *listClause + order *listClause + limit, offset int64 + // True if results of the union should be deduped. + unique bool +} + +func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { + us.where = expression + return us +} + +// Further filter the query, instead of replacing the filter +func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { + if us.where == nil { + return us.Where(expression) + } + us.where = And(us.where, expression) + return us +} + +func (us *unionStatementImpl) GroupBy( + expressions ...Expression) UnionStatement { + + us.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + us.group.clauses[i] = e + } + return us +} + +func (us *unionStatementImpl) OrderBy( + clauses ...OrderByClause) UnionStatement { + + us.order = newOrderByListClause(clauses...) + return us +} + +func (us *unionStatementImpl) Limit(limit int64) UnionStatement { + us.limit = limit + return us +} + +func (us *unionStatementImpl) Offset(offset int64) UnionStatement { + us.offset = offset + return us +} + +func (us *unionStatementImpl) String(database string) (sql string, err error) { + if len(us.selects) == 0 { + return "", errors.Newf("Union statement must have at least one SELECT") + } + + if len(us.selects) == 1 { + return us.selects[0].String(database) + } + + // Union statements in MySQL require that the same number of columns in each subquery + var projections []Projection + + for _, statement := range us.selects { + // do a type assertion to get at the underlying struct + statementImpl, ok := statement.(*selectStatementImpl) + if !ok { + return "", errors.Newf( + "Expected inner select statement to be of type " + + "selectStatementImpl") + } + + // check that for limit for statements with order by clauses + if statementImpl.order != nil && statementImpl.limit < 0 { + return "", errors.Newf( + "All inner selects in Union statement must have LIMIT if " + + "they have ORDER BY") + } + + // check number of projections + if projections == nil { + projections = statementImpl.projections + } else { + if len(projections) != len(statementImpl.projections) { + return "", errors.Newf( + "All inner selects in Union statement must select the " + + "same number of columns. For sanity, you probably " + + "want to select the same table columns in the same " + + "order. If you are selecting on multiple tables, " + + "use Null to pad to the right number of fields.") + } + } + } + + buf := new(bytes.Buffer) + for i, statement := range us.selects { + if i != 0 { + if us.unique { + _, _ = buf.WriteString(" UNION ") + } else { + _, _ = buf.WriteString(" UNION ALL ") + } + } + _, _ = buf.WriteString("(") + selectSql, err := statement.String(database) + if err != nil { + return "", err + } + _, _ = buf.WriteString(selectSql) + _, _ = buf.WriteString(")") + } + + if us.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = us.where.SerializeSql(buf); err != nil { + return + } + } + + if us.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = us.group.SerializeSql(buf); err != nil { + return + } + } + + if us.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = us.order.SerializeSql(buf); err != nil { + return + } + } + + if us.limit >= 0 { + if us.offset >= 0 { + _, _ = buf.WriteString( + fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) + } + } + return buf.String(), nil +} + +// +// SELECT Statement ============================================================ +// + +func newSelectStatement( + table ReadableTable, + projections []Projection) SelectStatement { + + return &selectStatementImpl{ + table: table, + projections: projections, + limit: -1, + offset: -1, + withSharedLock: false, + forUpdate: false, + distinct: false, + } +} + +// NOTE: SelectStatement purposely does not implement the Table interface since +// mysql's subquery performance is horrible. +type selectStatementImpl struct { + table ReadableTable + projections []Projection + where BoolExpression + group *listClause + order *listClause + comment string + limit, offset int64 + withSharedLock bool + forUpdate bool + distinct bool +} + +func (s *selectStatementImpl) Copy() SelectStatement { + ret := *s + return &ret +} + +// Further filter the query, instead of replacing the filter +func (q *selectStatementImpl) AndWhere( + expression BoolExpression) SelectStatement { + + if q.where == nil { + return q.Where(expression) + } + q.where = And(q.where, expression) + return q +} + +func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { + q.where = expression + return q +} + +func (q *selectStatementImpl) GroupBy( + expressions ...Expression) SelectStatement { + + q.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + q.group.clauses[i] = e + } + return q +} + +func (q *selectStatementImpl) OrderBy( + clauses ...OrderByClause) SelectStatement { + + q.order = newOrderByListClause(clauses...) + return q +} + +func (q *selectStatementImpl) Limit(limit int64) SelectStatement { + q.limit = limit + return q +} + +func (q *selectStatementImpl) Distinct() SelectStatement { + q.distinct = true + return q +} + +func (q *selectStatementImpl) WithSharedLock() SelectStatement { + // We don't need to grab a read lock if we're going to grab a write one + if !q.forUpdate { + q.withSharedLock = true + } + return q +} + +func (q *selectStatementImpl) ForUpdate() SelectStatement { + // Clear a request for a shared lock if we're asking for a write one + q.withSharedLock = false + q.forUpdate = true + return q +} + +func (q *selectStatementImpl) Offset(offset int64) SelectStatement { + q.offset = offset + return q +} + +func (q *selectStatementImpl) Comment(comment string) SelectStatement { + q.comment = comment + return q +} + +// Return the properly escaped SQL statement, against the specified database +func (q *selectStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("SELECT ") + + if err = writeComment(q.comment, buf); err != nil { + return + } + + if q.distinct { + _, _ = buf.WriteString("DISTINCT ") + } + + if q.projections == nil || len(q.projections) == 0 { + return "", errors.Newf( + "No column selected. Generated sql: %s", + buf.String()) + } + + for i, col := range q.projections { + if i > 0 { + _ = buf.WriteByte(',') + } + if col == nil { + return "", errors.Newf( + "nil column selected. Generated sql: %s", + buf.String()) + } + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + _, _ = buf.WriteString(" FROM ") + if q.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + if err = q.table.SerializeSql(database, buf); err != nil { + return + } + + if q.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = q.where.SerializeSql(buf); err != nil { + return + } + } + + if q.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = q.group.SerializeSql(buf); err != nil { + return + } + } + + if q.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = q.order.SerializeSql(buf); err != nil { + return + } + } + + if q.limit >= 0 { + if q.offset >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) + } + } + + if q.forUpdate { + _, _ = buf.WriteString(" FOR UPDATE") + } else if q.withSharedLock { + _, _ = buf.WriteString(" LOCK IN SHARE MODE") + } + + return buf.String(), nil +} + +// +// INSERT Statement ============================================================ +// + +func newInsertStatement( + t WritableTable, + columns ...NonAliasColumn) InsertStatement { + + return &insertStatementImpl{ + table: t, + columns: columns, + rows: make([][]Expression, 0, 1), + onDuplicateKeyUpdates: make([]columnAssignment, 0, 0), + } +} + +type columnAssignment struct { + col NonAliasColumn + expr Expression +} + +type insertStatementImpl struct { + table WritableTable + columns []NonAliasColumn + rows [][]Expression + onDuplicateKeyUpdates []columnAssignment + comment string + ignore bool +} + +func (s *insertStatementImpl) Add( + row ...Expression) InsertStatement { + + s.rows = append(s.rows, row) + return s +} + +func (s *insertStatementImpl) AddOnDuplicateKeyUpdate( + col NonAliasColumn, + expr Expression) InsertStatement { + + s.onDuplicateKeyUpdates = append( + s.onDuplicateKeyUpdates, + columnAssignment{col, expr}) + + return s +} + +func (s *insertStatementImpl) IgnoreDuplicates(ignore bool) InsertStatement { + s.ignore = ignore + return s +} + +func (s *insertStatementImpl) Comment(comment string) InsertStatement { + s.comment = comment + return s +} + +func (s *insertStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("INSERT ") + if s.ignore { + _, _ = buf.WriteString("IGNORE ") + } + _, _ = buf.WriteString("INTO ") + + if err = writeComment(s.comment, buf); err != nil { + return + } + + if s.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = s.table.SerializeSql(database, buf); err != nil { + return + } + + if len(s.columns) == 0 { + return "", errors.Newf( + "No column specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" (") + for i, col := range s.columns { + if i > 0 { + _ = buf.WriteByte(',') + } + + if col == nil { + return "", errors.Newf( + "nil column in columns list. Generated sql: %s", + buf.String()) + } + + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + if len(s.rows) == 0 { + return "", errors.Newf( + "No row specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(") VALUES (") + for row_i, row := range s.rows { + if row_i > 0 { + _, _ = buf.WriteString(", (") + } + + if len(row) != len(s.columns) { + return "", errors.Newf( + "# of values does not match # of columns. Generated sql: %s", + buf.String()) + } + + for col_i, value := range row { + if col_i > 0 { + _ = buf.WriteByte(',') + } + + if value == nil { + return "", errors.Newf( + "nil value in row %d col %d. Generated sql: %s", + row_i, + col_i, + buf.String()) + } + + if err = value.SerializeSql(buf); err != nil { + return + } + } + _ = buf.WriteByte(')') + } + + if len(s.onDuplicateKeyUpdates) > 0 { + _, _ = buf.WriteString(" ON DUPLICATE KEY UPDATE ") + for i, colExpr := range s.onDuplicateKeyUpdates { + if i > 0 { + _, _ = buf.WriteString(", ") + } + + if colExpr.col == nil { + return "", errors.Newf( + ("nil column in on duplicate key update list. " + + "Generated sql: %s"), + buf.String()) + } + + if err = colExpr.col.SerializeSqlForColumnList(buf); err != nil { + return + } + + _ = buf.WriteByte('=') + + if colExpr.expr == nil { + return "", errors.Newf( + ("nil expression in on duplicate key update list. " + + "Generated sql: %s"), + buf.String()) + } + + if err = colExpr.expr.SerializeSql(buf); err != nil { + return + } + } + } + + return buf.String(), nil +} + +// +// UPDATE statement =========================================================== +// + +func newUpdateStatement(table WritableTable) UpdateStatement { + return &updateStatementImpl{ + table: table, + updateValues: make(map[NonAliasColumn]Expression), + limit: -1, + } +} + +type updateStatementImpl struct { + table WritableTable + updateValues map[NonAliasColumn]Expression + where BoolExpression + order *listClause + limit int64 + comment string +} + +func (u *updateStatementImpl) Set( + column NonAliasColumn, + expression Expression) UpdateStatement { + + u.updateValues[column] = expression + return u +} + +func (u *updateStatementImpl) Where(expression BoolExpression) UpdateStatement { + u.where = expression + return u +} + +func (u *updateStatementImpl) OrderBy( + clauses ...OrderByClause) UpdateStatement { + + u.order = newOrderByListClause(clauses...) + return u +} + +func (u *updateStatementImpl) Limit(limit int64) UpdateStatement { + u.limit = limit + return u +} + +func (u *updateStatementImpl) Comment(comment string) UpdateStatement { + u.comment = comment + return u +} + +func (u *updateStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("UPDATE ") + + if err = writeComment(u.comment, buf); err != nil { + return + } + + if u.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = u.table.SerializeSql(database, buf); err != nil { + return + } + + if len(u.updateValues) == 0 { + return "", errors.Newf( + "No column updated. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" SET ") + addComma := false + + // Sorting is too hard in go, just create a second map ... + updateValues := make(map[string]Expression) + for col, expr := range u.updateValues { + if col == nil { + return "", errors.Newf( + "nil column. Generated sql: %s", + buf.String()) + } + + updateValues[col.Name()] = expr + } + + for _, col := range u.table.Columns() { + val, inMap := updateValues[col.Name()] + if !inMap { + continue + } + + if addComma { + _, _ = buf.WriteString(", ") + } + + if val == nil { + return "", errors.Newf( + "nil value. Generated sql: %s", + buf.String()) + } + + if err = col.SerializeSql(buf); err != nil { + return + } + + _ = buf.WriteByte('=') + if err = val.SerializeSql(buf); err != nil { + return + } + + addComma = true + } + + if u.where == nil { + return "", errors.Newf( + "Updating without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = u.where.SerializeSql(buf); err != nil { + return + } + + if u.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = u.order.SerializeSql(buf); err != nil { + return + } + } + + if u.limit >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", u.limit)) + } + + return buf.String(), nil +} + +// +// DELETE statement =========================================================== +// + +func newDeleteStatement(table WritableTable) DeleteStatement { + return &deleteStatementImpl{ + table: table, + limit: -1, + } +} + +type deleteStatementImpl struct { + table WritableTable + where BoolExpression + order *listClause + limit int64 + comment string +} + +func (d *deleteStatementImpl) Where(expression BoolExpression) DeleteStatement { + d.where = expression + return d +} + +func (d *deleteStatementImpl) OrderBy( + clauses ...OrderByClause) DeleteStatement { + + d.order = newOrderByListClause(clauses...) + return d +} + +func (d *deleteStatementImpl) Limit(limit int64) DeleteStatement { + d.limit = limit + return d +} + +func (d *deleteStatementImpl) Comment(comment string) DeleteStatement { + d.comment = comment + return d +} + +func (d *deleteStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("DELETE FROM ") + + if err = writeComment(d.comment, buf); err != nil { + return + } + + if d.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = d.table.SerializeSql(database, buf); err != nil { + return + } + + if d.where == nil { + return "", errors.Newf( + "Deleting without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = d.where.SerializeSql(buf); err != nil { + return + } + + if d.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = d.order.SerializeSql(buf); err != nil { + return + } + } + + if d.limit >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", d.limit)) + } + + return buf.String(), nil +} + +// +// LOCK statement =========================================================== +// + +// NewLockStatement returns a SQL representing empty set of locks. You need to use +// AddReadLock/AddWriteLock to add tables that need to be locked. +// NOTE: You need at least one lock in the set for it to be a valid statement. +func NewLockStatement() LockStatement { + return &lockStatementImpl{} +} + +type lockStatementImpl struct { + locks []tableLock +} + +type tableLock struct { + t *Table + w bool +} + +// AddReadLock takes read lock on the table. +func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { + s.locks = append(s.locks, tableLock{t: t, w: false}) + return s +} + +// AddWriteLock takes write lock on the table. +func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { + s.locks = append(s.locks, tableLock{t: t, w: true}) + return s +} + +func (s *lockStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + if len(s.locks) == 0 { + return "", errors.New("No locks added") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("LOCK TABLES ") + + for idx, lock := range s.locks { + if lock.t == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = lock.t.SerializeSql(database, buf); err != nil { + return + } + + if lock.w { + _, _ = buf.WriteString(" WRITE") + } else { + _, _ = buf.WriteString(" READ") + } + + if idx != len(s.locks)-1 { + _, _ = buf.WriteString(", ") + } + } + + return buf.String(), nil +} + +// NewUnlockStatement returns SQL statement that can be used to release table locks +// grabbed by the current session. +func NewUnlockStatement() UnlockStatement { + return &unlockStatementImpl{} +} + +type unlockStatementImpl struct { +} + +func (s *unlockStatementImpl) String(database string) (sql string, err error) { + return "UNLOCK TABLES", nil +} + +// Set GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. +func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement { + return >idNextStatementImpl{ + sid: sid, + gno: gno, + } +} + +type gtidNextStatementImpl struct { + sid []byte + gno uint64 +} + +func (s *gtidNextStatementImpl) String(database string) (sql string, err error) { + // This statement sets a session local variable defining what the next transaction ID is. It + // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we + // don't have to worry about data corruption. + // Because of the string formatting (hex plus an integer), can't morph into another statement. + // See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html + const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\"" + + buf := new(bytes.Buffer) + _, _ = buf.WriteString(fmt.Sprintf(gtidFormatString, + s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno)) + return buf.String(), nil +} + +// +// Util functions ============================================================= +// + +// Once again, teisenberger is lazy. Here's a quick filter on comments +var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$") + +func isValidComment(comment string) bool { + return validCommentRegexp.MatchString(comment) +} + +func writeComment(comment string, buf *bytes.Buffer) error { + if comment != "" { + _, _ = buf.WriteString("/* ") + if !isValidComment(comment) { + return errors.Newf("Invalid comment: %s", comment) + } + _, _ = buf.WriteString(comment) + _, _ = buf.WriteString(" */") + } + return nil +} + +func newOrderByListClause(clauses ...OrderByClause) *listClause { + ret := &listClause{ + clauses: make([]Clause, len(clauses), len(clauses)), + includeParentheses: false, + } + + for i, c := range clauses { + ret.clauses[i] = c + } + + return ret +} diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go new file mode 100644 index 0000000..83d1998 --- /dev/null +++ b/sqlbuilder/statement_test.go @@ -0,0 +1,660 @@ +package sqlbuilder + +import ( + "time" + + gc "gopkg.in/check.v1" + + "github.com/dropbox/godropbox/errors" +) + +type StmtSuite struct { +} + +var _ = gc.Suite(&StmtSuite{}) + +// NOTE: tables / columns are defined in test_utils.go + +// +// SELECT statement tests +// + +func (s *StmtSuite) TestSelectEmptyProjection(c *gc.C) { + _, err := table1.Select().String("db") + + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestSelectSingleColumn(c *gc.C) { + sql, err := table1.Select(table1Col1).String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1`") +} + +func (s *StmtSuite) TestSelectMultiColumns(c *gc.C) { + sql, err := table1.Select(table1Col1, table1Col2).String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table1`.`col2` FROM `db`.`table1`") +} + +func (s *StmtSuite) TestSelectWhere(c *gc.C) { + q := table1.Select(table1Col1).Where(GtL(table1Col1, 123)) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` WHERE `table1`.`col1`>123") +} + +func (s *StmtSuite) TestSelectWhereDate(c *gc.C) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + q := table1.Select(table1Col1).Where(GtL(table1Col4, date)) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` "+ + "WHERE `table1`.`col4`>'1999-01-02 03:04:05.000000'") +} + +func (s *StmtSuite) TestSelectAndWhere(c *gc.C) { + q := table1.Select(table1Col1).AndWhere(GtL(table1Col1, 123)) + q.AndWhere(LtL(table1Col1, 321)) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` WHERE (`table1`.`col1`>123 AND `table1`.`col1`<321)") +} + +func (s *StmtSuite) TestSelectCopy(c *gc.C) { + q := table1.Select(table1Col1).Where(GtL(table1Col1, 123)) + qq := q.Copy().Where(GtL(table1Col1, 321)).OrderBy(table1Col1) + + // Initial query unchanged + sql, err := q.String("db") + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` WHERE `table1`.`col1`>123") + // New query changed + sql, err = qq.String("db") + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` WHERE `table1`.`col1`>321 ORDER BY `table1`.`col1`") + +} + +func (s *StmtSuite) TestSelectLimitWithoutOffset(c *gc.C) { + q := table1.Select(table1Col1).Limit(5) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` LIMIT 5") +} + +func (s *StmtSuite) TestSelectLimitWithOffset(c *gc.C) { + q := table1.Select(table1Col1).Limit(5).Offset(2) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` LIMIT 2, 5") +} + +func (s *StmtSuite) TestSelectGroupBy(c *gc.C) { + q := table1.Select( + table1Col1, + table1Col2, + Alias("total", SqlFunc("sum", table1Col3))) + q.GroupBy(table1Col1, table1Col2) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table1`.`col2`,"+ + "(sum(`table1`.`col3`)) AS `total` "+ + "FROM `db`.`table1` GROUP BY `table1`.`col1`,`table1`.`col2`") +} + +func (s *StmtSuite) TestSelectSingleOrderBy(c *gc.C) { + q := table1.Select(table1Col1, table1Col2).OrderBy(table1Col2) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table1`.`col2` "+ + "FROM `db`.`table1` ORDER BY `table1`.`col2`") +} + +func (s *StmtSuite) TestSelectOrderByAsc(c *gc.C) { + q := table1.Select(table1Col1, table1Col2).OrderBy(Asc(table1Col2)) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table1`.`col2` "+ + "FROM `db`.`table1` ORDER BY `table1`.`col2` ASC") +} + +func (s *StmtSuite) TestSelectOrderByDesc(c *gc.C) { + q := table1.Select(table1Col1, table1Col2).OrderBy(Desc(table1Col2)) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table1`.`col2` "+ + "FROM `db`.`table1` ORDER BY `table1`.`col2` DESC") +} + +func (s *StmtSuite) TestSelectMultiOrderBy(c *gc.C) { + q := table1.Select(table1Col1, table1Col2) + q.OrderBy(table1Col2, table1Col1) + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table1`.`col2` "+ + "FROM `db`.`table1` "+ + "ORDER BY `table1`.`col2`,`table1`.`col1`") +} + +func (s *StmtSuite) TestSelectOnJoin(c *gc.C) { + + join := table1.InnerJoinOn(table2, Eq(table1Col3, table2Col3)) + sql, err := join.Select(table1Col1, table2Col4).String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1`,`table2`.`col4` "+ + "FROM `db`.`table1` JOIN `db`.`table2` "+ + "ON `table1`.`col3`=`table2`.`col3`") +} + +func (s *StmtSuite) TestSelectWithSharedLock(c *gc.C) { + + q := table1.Select(table1Col1).Where(GtL(table1Col1, 123)).WithSharedLock() + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT `table1`.`col1` FROM `db`.`table1` "+ + "WHERE `table1`.`col1`>123 LOCK IN SHARE MODE") +} + +func (s *StmtSuite) TestSelectDistinct(c *gc.C) { + q := table1.Select(table1Col1).Distinct() + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "SELECT DISTINCT `table1`.`col1` FROM `db`.`table1`") +} + +// +// INSERT statement tests +// + +func (s *StmtSuite) TestInsertNoColumn(c *gc.C) { + _, err := table1.Insert().Add().String("db") + + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestInsertNoRow(c *gc.C) { + _, err := table1.Insert(table1Col1).String("db") + + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestInsertColumnLengthMismatch(c *gc.C) { + _, err := table1.Insert(table1Col1, table1Col2).Add(nil).String("db") + + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestInsertNilValue(c *gc.C) { + _, err := table1.Insert(table1Col1).Add(nil).String("db") + + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestInsertNilColumn(c *gc.C) { + _, err := table1.Insert(nil).Add(Literal(1)).String("db") + + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestInsertSingleValue(c *gc.C) { + sql, err := table1.Insert(table1Col1).Add(Literal(1)).String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT INTO `db`.`table1` (`table1`.`col1`) VALUES (1)") +} + +func (s *StmtSuite) TestInsertDate(c *gc.C) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + sql, err := table1.Insert(table1Col4).Add(Literal(date)).String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT INTO `db`.`table1` (`table1`.`col4`) "+ + "VALUES ('1999-01-02 03:04:05.000000')") +} + +func (s *StmtSuite) TestInsertIgnore(c *gc.C) { + stmt := table1.Insert(table1Col1).Add(Literal(1)).IgnoreDuplicates(true) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT IGNORE INTO `db`.`table1` (`table1`.`col1`) VALUES (1)") +} + +func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) { + stmt := table1.Insert(table1Col1, table1Col2, table1Col3) + stmt.Add(Literal(1), Literal(2), Literal(3)) + + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT INTO `db`.`table1` "+ + "(`table1`.`col1`,`table1`.`col2`,`table1`.`col3`) "+ + "VALUES (1,2,3)") +} + +func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) { + stmt := table1.Insert(table1Col1, table1Col2) + stmt.Add(Literal(1), Literal(2)) + stmt.Add(Literal(11), Literal(22)) + stmt.Add(Literal(111), Literal(222)) + + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT INTO `db`.`table1` "+ + "(`table1`.`col1`,`table1`.`col2`) "+ + "VALUES (1,2), (11,22), (111,222)") +} + +func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) { + stmt := table1.Insert(table1Col1, table1Col2) + stmt.Add(Literal(1), Literal(2)) + stmt.AddOnDuplicateKeyUpdate(nil, Literal(3)) + + _, err := stmt.String("db") + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) { + stmt := table1.Insert(table1Col1, table1Col2) + stmt.Add(Literal(1), Literal(2)) + stmt.AddOnDuplicateKeyUpdate(table1Col1, nil) + + _, err := stmt.String("db") + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) { + stmt := table1.Insert(table1Col1, table1Col2) + stmt.Add(Literal(1), Literal(2)) + stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3)) + + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT INTO `db`.`table1` "+ + "(`table1`.`col1`,`table1`.`col2`) "+ + "VALUES (1,2) "+ + "ON DUPLICATE KEY UPDATE `table1`.`col3`=3") +} + +func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) { + stmt := table1.Insert(table1Col1, table1Col2) + stmt.Add(Literal(1), Literal(2)) + stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3)) + stmt.AddOnDuplicateKeyUpdate(table1Col2, Literal(4)) + + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "INSERT INTO `db`.`table1` "+ + "(`table1`.`col1`,`table1`.`col2`) "+ + "VALUES (1,2) "+ + "ON DUPLICATE KEY UPDATE `table1`.`col3`=3, `table1`.`col2`=4") +} + +// +// UPDATE statement tests ===================================================== +// + +func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) { + stmt := table1.Update().Set(nil, Literal(1)) + _, err := stmt.String("db") + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) { + stmt := table1.Update().Set(table1Col1, nil) + _, err := stmt.String("db") + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) { + stmt := table1.Update().Set(table1Col1, Literal(1)) + _, err := stmt.String("db") + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { + stmt := table1.Update().Set(table1Col1, Literal(1)) + stmt.Where(EqL(table1Col2, 2)) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "UPDATE `db`.`table1` SET `table1`.`col1`=1 WHERE `table1`.`col2`=2") +} + +func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { + stmt := table1.Update().Set(table1.C("col1"), Literal(1)) + stmt.Where(EqL(table1Col2, 2)) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "UPDATE `db`.`table1` SET `table1`.`col1`=1 WHERE `table1`.`col2`=2") +} + +func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) { + stmt := table1.Update() + stmt.Set(table1Col1, Literal(1)) + stmt.Set(table1Col2, Literal(2)) + stmt.Where(EqL(table1Col2, 3)) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "UPDATE `db`.`table1` "+ + "SET `table1`.`col1`=1, `table1`.`col2`=2 "+ + "WHERE `table1`.`col2`=3") +} + +func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { + stmt := table1.Update().Set(table1Col1, Literal(1)) + stmt.Where(EqL(table1Col2, 2)) + stmt.OrderBy(table1Col2) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "UPDATE `db`.`table1` "+ + "SET `table1`.`col1`=1 "+ + "WHERE `table1`.`col2`=2 "+ + "ORDER BY `table1`.`col2`") +} + +func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { + stmt := table1.Update().Set(table1Col1, Literal(1)) + stmt.Where(EqL(table1Col2, 2)) + stmt.Limit(5) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "UPDATE `db`.`table1` "+ + "SET `table1`.`col1`=1 "+ + "WHERE `table1`.`col2`=2 "+ + "LIMIT 5") +} + +// +// DELETE statement tests ===================================================== +// + +func (s *StmtSuite) TestDeleteUnconditionally(c *gc.C) { + _, err := table1.Delete().String("db") + c.Assert(err, gc.NotNil) +} + +func (s *StmtSuite) TestDeleteWithWhere(c *gc.C) { + sql, err := table1.Delete().Where(EqL(table1Col1, 1)).String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "DELETE FROM `db`.`table1` WHERE `table1`.`col1`=1") +} + +func (s *StmtSuite) TestDeleteWithOrderBy(c *gc.C) { + stmt := table1.Delete().Where(EqL(table1Col1, 1)).OrderBy(table1Col1) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "DELETE FROM `db`.`table1` "+ + "WHERE `table1`.`col1`=1 "+ + "ORDER BY `table1`.`col1`") +} + +func (s *StmtSuite) TestDeleteWithLimit(c *gc.C) { + stmt := table1.Delete().Where(EqL(table1Col1, 1)).Limit(5) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert( + sql, + gc.Equals, + "DELETE FROM `db`.`table1` WHERE `table1`.`col1`=1 LIMIT 5") +} + +// +// LOCK/UNLOCK statement tests ================================================ +// + +func (s *StmtSuite) TestLockStatement(c *gc.C) { + stmt := NewLockStatement().AddReadLock(table1).AddWriteLock(table2) + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + + c.Assert(sql, gc.Equals, "LOCK TABLES `db`.`table1` READ, `db`.`table2` WRITE") +} + +func (s *StmtSuite) TestUnlockStatement(c *gc.C) { + stmt := NewUnlockStatement() + sql, err := stmt.String("db") + c.Assert(err, gc.IsNil) + c.Assert(sql, gc.Equals, "UNLOCK TABLES") + +} + +func (s *StmtSuite) TestUnionSelectStatement(c *gc.C) { + select_queries := make([]SelectStatement, 0, 3) + + select_queries = append(select_queries, + table1.Select(table1Col1).Where(GtL(table1Col1, 123)), + table1.Select(table1Col1).Where(GtL(table1Col1, 456)), + table1.Select(table1Col1).Where(LtL(table1Col1, 23)), + ) + + q := Union(select_queries...) + + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "(SELECT `table1`.`col1` FROM `db`.`table1` WHERE `table1`.`col1`>123) "+ + "UNION (SELECT `table1`.`col1` FROM `db`.`table1` WHERE `table1`.`col1`>456) "+ + "UNION (SELECT `table1`.`col1` FROM `db`.`table1` WHERE `table1`.`col1`<23)") +} + +func (s *StmtSuite) TestUnionLimitWithoutOrderBy(c *gc.C) { + select_queries := make([]SelectStatement, 0, 3) + + select_queries = append(select_queries, + table1.Select(table1Col1).Where(GtL(table1Col1, 123)).OrderBy(table1Col2), + table1.Select(table1Col1).Where(GtL(table1Col1, 456)), + table1.Select(table1Col1).Where(LtL(table1Col1, 23)), + ) + + q := Union(select_queries...) + + _, err := q.String("db") + + c.Assert(err, gc.NotNil) + c.Assert( + errors.GetMessage(err), + gc.Equals, + "All inner selects in Union statement must have LIMIT if they have ORDER BY") +} + +func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { + select_queries := make([]SelectStatement, 0, 3) + + select_queries = append(select_queries, + + table1.Select( + table1Col1, + table1Col2, + table1Col3, + table1Col4).AndWhere(GtL(table1Col1, 123)).AndWhere(LtL(table1Col1, 321)), + table1.Select(table1Col1).Where(And(GtL(table1Col1, 123), LtL(table1Col1, 321))), + table1.Select(table1Col1).Where(LtL(table1Col1, 23)).OrderBy(table1Col4).Limit(20), + ) + + q := Union(select_queries...) + q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) + q = q.OrderBy(Desc(table1Col4), Asc(table1Col3)) + q = q.Limit(5) + + _, err := q.String("db") + + c.Assert(err, gc.NotNil) + c.Assert( + errors.GetMessage(err), + gc.Equals, + "All inner selects in Union statement must select the "+ + "same number of columns. For sanity, you probably "+ + "want to select the same table columns in the same "+ + "order. If you are selecting on multiple tables, "+ + "use Null to pad to the right number of fields.") +} + +func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { + + // tests on outer statement: Group By, Order By, Limit + // on inner statement: AndWhere, Where (with And), Order By, Limit + select_queries := make([]SelectStatement, 0, 3) + + // We're not trying to write a SQL parser, so we won't warn if you do something silly like + // try to apply a where clause on more columns than you've selected in your union select + select_queries = append(select_queries, + table1.Select( + table1Col1, + ).AndWhere(GtL(table1Col1, 123)).AndWhere(LtL(table1Col1, 321)), + table1.Select( + table1Col1, + ).Where(And(GtL(table1Col1, 456), LtL(table1Col1, 654))), + table1.Select( + table1Col1, + ).Where(LtL(table1Col1, 23)).OrderBy(table1Col4).Limit(20), + ) + + q := Union(select_queries...) + q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) + + q = q.OrderBy(Desc(table1Col4), Asc(table1Col3)) + q = q.Limit(5) + q = q.GroupBy(table1Col4) + + sql, err := q.String("db") + + c.Assert(err, gc.IsNil) + c.Assert( + sql, + gc.Equals, + "(SELECT `table1`.`col1` FROM `db`.`table1` WHERE "+ + "(`table1`.`col1`>123 AND `table1`.`col1`<321)) "+ + "UNION (SELECT `table1`.`col1` FROM `db`.`table1` "+ + "WHERE (`table1`.`col1`>456 AND `table1`.`col1`<654)) "+ + "UNION (SELECT `table1`.`col1` FROM `db`.`table1` "+ + "WHERE `table1`.`col1`<23 ORDER BY `table1`.`col4` LIMIT 20) "+ + "WHERE (`table1`.`col1`<1000 AND `table1`.`col1`>15) "+ + "GROUP BY `table1`.`col4` ORDER BY `table1`.`col4` DESC,`table1`.`col3` ASC "+ + "LIMIT 5") + +} diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go new file mode 100644 index 0000000..dfcc4a9 --- /dev/null +++ b/sqlbuilder/table.go @@ -0,0 +1,317 @@ +// Modeling of tables. This is where query preparation starts + +package sqlbuilder + +import ( + "bytes" + "fmt" + + "github.com/dropbox/godropbox/errors" +) + +// The sql table read interface. NOTE: NATURAL JOINs, and join "USING" clause +// are not supported. +type ReadableTable interface { + // Returns the list of columns that are in the current table expression. + Columns() []NonAliasColumn + + // Generates the sql string for the current table expression. Note: the + // generated string may not be a valid/executable sql statement. + // The database is the name of the database the table is on + SerializeSql(database string, out *bytes.Buffer) error + + // Generates a select query on the current table. + Select(projections ...Projection) SelectStatement + + // Creates a inner join table expression using onCondition. + InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a left join table expression using onCondition. + LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a right join table expression using onCondition. + RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable +} + +// The sql table write interface. +type WritableTable interface { + // Returns the list of columns that are in the table. + Columns() []NonAliasColumn + + // Generates the sql string for the current table expression. Note: the + // generated string may not be a valid/executable sql statement. + // The database is the name of the database the table is on + SerializeSql(database string, out *bytes.Buffer) error + + Insert(columns ...NonAliasColumn) InsertStatement + Update() UpdateStatement + Delete() DeleteStatement +} + +// Defines a physical table in the database that is both readable and writable. +// This function will panic if name is not valid +func NewTable(name string, columns ...NonAliasColumn) *Table { + if !validIdentifierName(name) { + panic("Invalid table name") + } + + t := &Table{ + name: name, + columns: columns, + columnLookup: make(map[string]NonAliasColumn), + } + for _, c := range columns { + err := c.setTableName(name) + if err != nil { + panic(err) + } + t.columnLookup[c.Name()] = c + } + + if len(columns) == 0 { + panic(fmt.Sprintf("Table %s has no columns", name)) + } + + return t +} + +type Table struct { + name string + columns []NonAliasColumn + columnLookup map[string]NonAliasColumn + // If not empty, the name of the index to force + forcedIndex string +} + +// Returns the specified column, or errors if it doesn't exist in the table +func (t *Table) getColumn(name string) (NonAliasColumn, error) { + if c, ok := t.columnLookup[name]; ok { + return c, nil + } + return nil, errors.Newf("No such column '%s' in table '%s'", name, t.name) +} + +// Returns a pseudo column representation of the column name. Error checking +// is deferred to SerializeSql. +func (t *Table) C(name string) NonAliasColumn { + return &deferredLookupColumn{ + table: t, + colName: name, + } +} + +// Returns all columns for a table as a slice of projections +func (t *Table) Projections() []Projection { + result := make([]Projection, 0) + + for _, col := range t.columns { + result = append(result, col) + } + + return result +} + +// Returns the table's name in the database +func (t *Table) Name() string { + return t.name +} + +// Returns a list of the table's columns +func (t *Table) Columns() []NonAliasColumn { + return t.columns +} + +// Returns a copy of this table, but with the specified index forced. +func (t *Table) ForceIndex(index string) *Table { + newTable := *t + newTable.forcedIndex = index + return &newTable +} + +// Generates the sql string for the current table expression. Note: the +// generated string may not be a valid/executable sql statement. +func (t *Table) SerializeSql(database string, out *bytes.Buffer) error { + _, _ = out.WriteString("`") + _, _ = out.WriteString(database) + _, _ = out.WriteString("`.`") + _, _ = out.WriteString(t.Name()) + _, _ = out.WriteString("`") + + if t.forcedIndex != "" { + if !validIdentifierName(t.forcedIndex) { + return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex) + } + _, _ = out.WriteString(" FORCE INDEX (`") + _, _ = out.WriteString(t.forcedIndex) + _, _ = out.WriteString("`)") + } + + return nil +} + +// Generates a select query on the current table. +func (t *Table) Select(projections ...Projection) SelectStatement { + return newSelectStatement(t, projections) +} + +// Creates a inner join table expression using onCondition. +func (t *Table) InnerJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return InnerJoinOn(t, table, onCondition) +} + +// Creates a left join table expression using onCondition. +func (t *Table) LeftJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return LeftJoinOn(t, table, onCondition) +} + +// Creates a right join table expression using onCondition. +func (t *Table) RightJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return RightJoinOn(t, table, onCondition) +} + +func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement { + return newInsertStatement(t, columns...) +} + +func (t *Table) Update() UpdateStatement { + return newUpdateStatement(t) +} + +func (t *Table) Delete() DeleteStatement { + return newDeleteStatement(t) +} + +type joinType int + +const ( + INNER_JOIN joinType = iota + LEFT_JOIN + RIGHT_JOIN +) + +// Join expressions are pseudo readable tables. +type joinTable struct { + lhs ReadableTable + rhs ReadableTable + join_type joinType + onCondition BoolExpression +} + +func newJoinTable( + lhs ReadableTable, + rhs ReadableTable, + join_type joinType, + onCondition BoolExpression) ReadableTable { + + return &joinTable{ + lhs: lhs, + rhs: rhs, + join_type: join_type, + onCondition: onCondition, + } +} + +func InnerJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, INNER_JOIN, onCondition) +} + +func LeftJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, LEFT_JOIN, onCondition) +} + +func RightJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition) +} + +func (t *joinTable) Columns() []NonAliasColumn { + columns := make([]NonAliasColumn, 0) + columns = append(columns, t.lhs.Columns()...) + columns = append(columns, t.rhs.Columns()...) + + return columns +} + +func (t *joinTable) SerializeSql( + database string, + out *bytes.Buffer) (err error) { + + if t.lhs == nil { + return errors.Newf("nil lhs. Generated sql: %s", out.String()) + } + if t.rhs == nil { + return errors.Newf("nil rhs. Generated sql: %s", out.String()) + } + if t.onCondition == nil { + return errors.Newf("nil onCondition. Generated sql: %s", out.String()) + } + + if err = t.lhs.SerializeSql(database, out); err != nil { + return + } + + switch t.join_type { + case INNER_JOIN: + _, _ = out.WriteString(" JOIN ") + case LEFT_JOIN: + _, _ = out.WriteString(" LEFT JOIN ") + case RIGHT_JOIN: + _, _ = out.WriteString(" RIGHT JOIN ") + } + + if err = t.rhs.SerializeSql(database, out); err != nil { + return + } + + _, _ = out.WriteString(" ON ") + if err = t.onCondition.SerializeSql(out); err != nil { + return + } + + return nil +} + +func (t *joinTable) Select(projections ...Projection) SelectStatement { + return newSelectStatement(t, projections) +} + +func (t *joinTable) InnerJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return InnerJoinOn(t, table, onCondition) +} + +func (t *joinTable) LeftJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return LeftJoinOn(t, table, onCondition) +} + +func (t *joinTable) RightJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return RightJoinOn(t, table, onCondition) +} diff --git a/sqlbuilder/table_test.go b/sqlbuilder/table_test.go new file mode 100644 index 0000000..8b30fa7 --- /dev/null +++ b/sqlbuilder/table_test.go @@ -0,0 +1,209 @@ +package sqlbuilder + +import ( + "bytes" + + gc "gopkg.in/check.v1" +) + +type TableSuite struct { +} + +var _ = gc.Suite(&TableSuite{}) + +// NOTE: tables / columns are defined in statement_test.go + +func (s *TableSuite) TestBasicColumns(c *gc.C) { + cols := table1.Columns() + + c.Assert(len(cols), gc.Equals, 4) + c.Assert(cols[0], gc.Equals, table1Col1) + c.Assert(cols[1], gc.Equals, table1Col2) + c.Assert(cols[2], gc.Equals, table1Col3) + c.Assert(cols[3], gc.Equals, table1Col4) +} + +func (s *TableSuite) TestCValidLookup(c *gc.C) { + col := table1.C("col1") + + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert(sql, gc.Equals, "`table1`.`col1`") +} + +func (s *TableSuite) TestCInvalidLookup(c *gc.C) { + col := table1.C("foo") + + buf := &bytes.Buffer{} + + err := col.SerializeSql(buf) + c.Assert(err, gc.NotNil) +} + +func (s *TableSuite) TestValidForcedIndex(c *gc.C) { + t := table1.ForceIndex("foo") + buf := &bytes.Buffer{} + err := t.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + sql := buf.String() + c.Assert(sql, gc.Equals, "`db`.`table1` FORCE INDEX (`foo`)") + + // Ensure the original table is unchanged + buf = &bytes.Buffer{} + err = table1.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + sql = buf.String() + c.Assert(sql, gc.Equals, "`db`.`table1`") +} + +func (s *TableSuite) TestInvalidForcedIndex(c *gc.C) { + t := table1.ForceIndex("foo\x00") + buf := &bytes.Buffer{} + err := t.SerializeSql("db", buf) + c.Assert(err, gc.NotNil) +} + +func (s *TableSuite) TestJoinNilLeftTable(c *gc.C) { + join := InnerJoinOn(nil, table2, EqL(table2Col3, 123)) + + buf := &bytes.Buffer{} + + err := join.SerializeSql("db", buf) + c.Assert(err, gc.NotNil) +} + +func (s *TableSuite) TestJoinNilRightTable(c *gc.C) { + join := InnerJoinOn(table1, nil, EqL(table2Col3, 123)) + + buf := &bytes.Buffer{} + + err := join.SerializeSql("db", buf) + c.Assert(err, gc.NotNil) +} + +func (s *TableSuite) TestJoinNilOnCondition(c *gc.C) { + join := InnerJoinOn(table1, table2, nil) + + buf := &bytes.Buffer{} + + err := join.SerializeSql("db", buf) + c.Assert(err, gc.NotNil) +} + +func (s *TableSuite) TestInnerJoin(c *gc.C) { + join := table1.InnerJoinOn(table2, Eq(table1Col3, table2Col3)) + + buf := &bytes.Buffer{} + + err := join.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`db`.`table1` JOIN `db`.`table2` ON `table1`.`col3`=`table2`.`col3`") +} + +func (s *TableSuite) TestLeftJoin(c *gc.C) { + join := table1.LeftJoinOn(table2, Eq(table1Col3, table2Col3)) + + buf := &bytes.Buffer{} + + err := join.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`db`.`table1` LEFT JOIN `db`.`table2` "+ + "ON `table1`.`col3`=`table2`.`col3`") +} + +func (s *TableSuite) TestRightJoin(c *gc.C) { + join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3)) + + buf := &bytes.Buffer{} + + err := join.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`db`.`table1` RIGHT JOIN `db`.`table2` "+ + "ON `table1`.`col3`=`table2`.`col3`") +} + +func (s *TableSuite) TestJoinColumns(c *gc.C) { + join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3)) + + cols := join.Columns() + c.Assert(len(cols), gc.Equals, 6) + c.Assert(cols[0], gc.Equals, table1Col1) + c.Assert(cols[1], gc.Equals, table1Col2) + c.Assert(cols[2], gc.Equals, table1Col3) + c.Assert(cols[3], gc.Equals, table1Col4) + c.Assert(cols[4], gc.Equals, table2Col3) + c.Assert(cols[5], gc.Equals, table2Col4) +} + +func (s *TableSuite) TestNestedInnerJoin(c *gc.C) { + join1 := table1.InnerJoinOn(table2, Eq(table1Col3, table2Col3)) + join2 := join1.InnerJoinOn(table3, Eq(table1Col1, table3Col1)) + + buf := &bytes.Buffer{} + + err := join2.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`db`.`table1` "+ + "JOIN `db`.`table2` ON `table1`.`col3`=`table2`.`col3` "+ + "JOIN `db`.`table3` ON `table1`.`col1`=`table3`.`col1`") +} + +func (s *TableSuite) TestNestedLeftJoin(c *gc.C) { + join1 := table1.InnerJoinOn(table2, Eq(table1Col3, table2Col3)) + join2 := join1.LeftJoinOn(table3, Eq(table1Col1, table3Col1)) + + buf := &bytes.Buffer{} + + err := join2.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`db`.`table1` "+ + "JOIN `db`.`table2` ON `table1`.`col3`=`table2`.`col3` "+ + "LEFT JOIN `db`.`table3` ON `table1`.`col1`=`table3`.`col1`") +} + +func (s *TableSuite) TestNestedRightJoin(c *gc.C) { + join1 := table1.InnerJoinOn(table2, Eq(table1Col3, table2Col3)) + join2 := join1.RightJoinOn(table3, Eq(table1Col1, table3Col1)) + + buf := &bytes.Buffer{} + + err := join2.SerializeSql("db", buf) + c.Assert(err, gc.IsNil) + + sql := buf.String() + c.Assert( + sql, + gc.Equals, + "`db`.`table1` "+ + "JOIN `db`.`table2` ON `table1`.`col3`=`table2`.`col3` "+ + "RIGHT JOIN `db`.`table3` ON `table1`.`col1`=`table3`.`col1`") +} diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go new file mode 100644 index 0000000..a7a0250 --- /dev/null +++ b/sqlbuilder/test_utils.go @@ -0,0 +1,26 @@ +package sqlbuilder + +var table1Col1 = IntColumn("col1", Nullable) +var table1Col2 = IntColumn("col2", Nullable) +var table1Col3 = IntColumn("col3", Nullable) +var table1Col4 = DateTimeColumn("col4", Nullable) +var table1 = NewTable( + "table1", + table1Col1, + table1Col2, + table1Col3, + table1Col4) + +var table2Col3 = IntColumn("col3", Nullable) +var table2Col4 = IntColumn("col4", Nullable) +var table2 = NewTable( + "table2", + table2Col3, + table2Col4) + +var table3Col1 = IntColumn("col1", Nullable) +var table3Col2 = IntColumn("col2", Nullable) +var table3 = NewTable( + "table3", + table3Col1, + table3Col2) diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go new file mode 100644 index 0000000..c9d05ea --- /dev/null +++ b/sqlbuilder/types.go @@ -0,0 +1,79 @@ +package sqlbuilder + +import ( + "bytes" +) + +type Clause interface { + SerializeSql(out *bytes.Buffer) error +} + +// A clause that can be used in order by +type OrderByClause interface { + Clause + isOrderByClauseInterface +} + +// An expression +type Expression interface { + Clause + isExpressionInterface +} + +type BoolExpression interface { + Clause + isBoolExpressionInterface +} + +// A clause that is selectable. +type Projection interface { + Clause + isProjectionInterface + SerializeSqlForColumnList(out *bytes.Buffer) error +} + +// +// Boiler plates ... +// + +type isOrderByClauseInterface interface { + isOrderByClauseType() +} + +type isOrderByClause struct { +} + +func (o *isOrderByClause) isOrderByClauseType() { +} + +type isExpressionInterface interface { + isExpressionType() +} + +type isExpression struct { + isOrderByClause // can always use expression in order by. +} + +func (e *isExpression) isExpressionType() { +} + +type isBoolExpressionInterface interface { + isExpressionInterface + isBoolExpressionType() +} + +type isBoolExpression struct { +} + +func (e *isBoolExpression) isBoolExpressionType() { +} + +type isProjectionInterface interface { + isProjectionType() +} + +type isProjection struct { +} + +func (p *isProjection) isProjectionType() { +}