diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 1d8bbcc..66463ae 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -158,7 +158,11 @@ func (q *queryData) finalize() (string, []interface{}) { return q.buff.String() + ";\n", q.args } -func (q *queryData) insertArgument(arg interface{}) { +func (q *queryData) insertConstantArgument(arg interface{}) { + q.writeString(ArgToString(arg)) +} + +func (q *queryData) insertPreparedArgument(arg interface{}) { q.args = append(q.args, arg) argPlaceholder := "$" + strconv.Itoa(len(q.args)) diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index c4ed1a8..a9816ee 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -7,6 +7,7 @@ type funcExpressionImpl struct { name string expressions []expression + noBrackets bool } func ROW(expressions ...expression) expression { @@ -33,17 +34,40 @@ func (f *funcExpressionImpl) serialize(statement statementType, out *queryData, return errors.New("Function expressions is nil. ") } - out.writeString(f.name + "(") + addBrackets := !f.noBrackets || len(f.expressions) > 0 + + if addBrackets { + out.writeString(f.name + "(") + } else { + out.writeString(f.name) + } err := serializeExpressionList(statement, f.expressions, ", ", out) if err != nil { return err } - out.writeString(")") + + if addBrackets { + out.writeString(")") + } return nil } +type boolFunc struct { + funcExpressionImpl + boolInterfaceImpl +} + +func newBoolFunc(name string, expressions ...expression) BoolExpression { + boolFunc := &boolFunc{} + + boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + boolFunc.boolInterfaceImpl.parent = boolFunc + + return boolFunc +} + type floatFunc struct { funcExpressionImpl floatInterfaceImpl @@ -91,7 +115,7 @@ type dateFunc struct { dateInterfaceImpl } -func newDateFunc(name string, expressions ...expression) DateExpression { +func newDateFunc(name string, expressions ...expression) *dateFunc { dateFunc := &dateFunc{} dateFunc.funcExpressionImpl = *newFunc(name, expressions, dateFunc) @@ -100,18 +124,46 @@ func newDateFunc(name string, expressions ...expression) DateExpression { return dateFunc } -type boolFunc struct { +type timeFunc struct { funcExpressionImpl - boolInterfaceImpl + timeInterfaceImpl } -func newBoolFunc(name string, expressions ...expression) BoolExpression { - boolFunc := &boolFunc{} +func newTimeFunc(name string, expressions ...expression) *timeFunc { + timeFun := &timeFunc{} - boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) - boolFunc.boolInterfaceImpl.parent = boolFunc + timeFun.funcExpressionImpl = *newFunc(name, expressions, timeFun) + timeFun.timeInterfaceImpl.parent = timeFun - return boolFunc + return timeFun +} + +type timezFunc struct { + funcExpressionImpl + timezInterfaceImpl +} + +func newTimezFunc(name string, expressions ...expression) *timezFunc { + timezFun := &timezFunc{} + + timezFun.funcExpressionImpl = *newFunc(name, expressions, timezFun) + timezFun.timezInterfaceImpl.parent = timezFun + + return timezFun +} + +type timestampFunc struct { + funcExpressionImpl + timestampInterfaceImpl +} + +func newTimestampFunc(name string, expressions ...expression) *timestampFunc { + timestampFunc := ×tampFunc{} + + timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) + timestampFunc.timestampInterfaceImpl.parent = timestampFunc + + return timestampFunc } type timestampzFunc struct { @@ -119,7 +171,7 @@ type timestampzFunc struct { timestampzInterfaceImpl } -func newTimestampzFunc(name string, expressions ...expression) TimestampzExpression { +func newTimestampzFunc(name string, expressions ...expression) *timestampzFunc { timestampzFunc := ×tampzFunc{} timestampzFunc.funcExpressionImpl = *newFunc(name, expressions, timestampzFunc) @@ -258,6 +310,7 @@ func CHR(integerExpression IntegerExpression) StringExpression { return newStringFunc("CHR", integerExpression) } +// //func CONCAT(expressions ...expression) StringExpression { // return newStringFunc("CONCAT", expressions...) //} @@ -382,3 +435,71 @@ func TO_NUMBER(floatStr, format StringExpression) FloatExpression { func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression { return newTimestampzFunc("TO_TIMESTAMP", timestampzStr, format) } + +//----------------- Date/Time Functions and Operators ---------------// + +func CURRENT_DATE() DateExpression { + dateFunc := newDateFunc("CURRENT_DATE") + dateFunc.noBrackets = true + return dateFunc +} + +func CURRENT_TIME(precision ...int) TimezExpression { + var timezFunc *timezFunc + + if len(precision) > 0 { + timezFunc = newTimezFunc("CURRENT_TIME", ConstantLiteral(precision[0])) + } else { + timezFunc = newTimezFunc("CURRENT_TIME") + } + + timezFunc.noBrackets = true + + return timezFunc +} + +func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression { + var timestampzFunc *timestampzFunc + + if len(precision) > 0 { + timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstantLiteral(precision[0])) + } else { + timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") + } + + timestampzFunc.noBrackets = true + + return timestampzFunc +} + +func LOCALTIME(precision ...int) TimeExpression { + var timeFunc *timeFunc + + if len(precision) > 0 { + timeFunc = newTimeFunc("LOCALTIME", ConstantLiteral(precision[0])) + } else { + timeFunc = newTimeFunc("LOCALTIME") + } + + timeFunc.noBrackets = true + + return timeFunc +} + +func LOCALTIMESTAMP(precision ...int) TimestampExpression { + var timestampFunc *timestampFunc + + if len(precision) > 0 { + timestampFunc = newTimestampFunc("LOCALTIMESTAMP", ConstantLiteral(precision[0])) + } else { + timestampFunc = newTimestampFunc("LOCALTIMESTAMP") + } + + timestampFunc.noBrackets = true + + return timestampFunc +} + +func NOW() TimestampzExpression { + return newTimestampzFunc("NOW") +} diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 4269d1d..247b1a0 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -5,7 +5,8 @@ import "fmt" // Representation of an escaped literal type literalExpression struct { expressionInterfaceImpl - value interface{} + value interface{} + constant bool } func Literal(value interface{}) *literalExpression { @@ -15,8 +16,19 @@ func Literal(value interface{}) *literalExpression { return &exp } +func ConstantLiteral(value interface{}) *literalExpression { + exp := Literal(value) + exp.constant = true + + return exp +} + func (l literalExpression) serialize(statement statementType, out *queryData, options ...serializeOption) error { - out.insertArgument(l.value) + if l.constant { + out.insertConstantArgument(l.value) + } else { + out.insertPreparedArgument(l.value) + } return nil } diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 70d9fe1..d2ec6ed 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -174,13 +174,13 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { if s.limit >= 0 { out.nextLine() out.writeString("LIMIT") - out.insertArgument(s.limit) + out.insertPreparedArgument(s.limit) } if s.offset >= 0 { out.nextLine() out.writeString("OFFSET") - out.insertArgument(s.offset) + out.insertPreparedArgument(s.offset) } if s.forUpdate { diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index 3191f8c..6bc84ce 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -168,13 +168,13 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { if s.limit >= 0 { out.nextLine() out.writeString("LIMIT") - out.insertArgument(s.limit) + out.insertPreparedArgument(s.limit) } if s.offset >= 0 { out.nextLine() out.writeString("OFFSET") - out.insertArgument(s.offset) + out.insertPreparedArgument(s.offset) } return nil diff --git a/sqlbuilder/time_expression_test.go b/sqlbuilder/time_expression_test.go index 2b5c6bf..154de6d 100644 --- a/sqlbuilder/time_expression_test.go +++ b/sqlbuilder/time_expression_test.go @@ -6,30 +6,30 @@ import ( func TestTimeExpressionEQ(t *testing.T) { assertExpressionSerialize(t, table1ColTime.EQ(table2ColTime), "(table1.colTime = table2.colTime)") - assertExpressionSerialize(t, table1ColTime.EQ(Time(10, 20, 0, 0)), "(table1.colTime = $1)", "10:20:00.000") + assertExpressionSerialize(t, table1ColTime.EQ(Time(10, 20, 0, 0)), "(table1.colTime = $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionNOT_EQ(t *testing.T) { assertExpressionSerialize(t, table1ColTime.NOT_EQ(table2ColTime), "(table1.colTime != table2.colTime)") - assertExpressionSerialize(t, table1ColTime.NOT_EQ(Time(10, 20, 0, 0)), "(table1.colTime != $1)", "10:20:00.000") + assertExpressionSerialize(t, table1ColTime.NOT_EQ(Time(10, 20, 0, 0)), "(table1.colTime != $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionLT(t *testing.T) { assertExpressionSerialize(t, table1ColTime.LT(table2ColTime), "(table1.colTime < table2.colTime)") - assertExpressionSerialize(t, table1ColTime.LT(Time(10, 20, 0, 0)), "(table1.colTime < $1)", "10:20:00.000") + assertExpressionSerialize(t, table1ColTime.LT(Time(10, 20, 0, 0)), "(table1.colTime < $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionLT_EQ(t *testing.T) { assertExpressionSerialize(t, table1ColTime.LT_EQ(table2ColTime), "(table1.colTime <= table2.colTime)") - assertExpressionSerialize(t, table1ColTime.LT_EQ(Time(10, 20, 0, 0)), "(table1.colTime <= $1)", "10:20:00.000") + assertExpressionSerialize(t, table1ColTime.LT_EQ(Time(10, 20, 0, 0)), "(table1.colTime <= $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionGT(t *testing.T) { assertExpressionSerialize(t, table1ColTime.GT(table2ColTime), "(table1.colTime > table2.colTime)") - assertExpressionSerialize(t, table1ColTime.GT(Time(10, 20, 0, 0)), "(table1.colTime > $1)", "10:20:00.000") + assertExpressionSerialize(t, table1ColTime.GT(Time(10, 20, 0, 0)), "(table1.colTime > $1::time without time zone)", "10:20:00.000") } func TestTimeExpressionGT_EQ(t *testing.T) { assertExpressionSerialize(t, table1ColTime.GT_EQ(table2ColTime), "(table1.colTime >= table2.colTime)") - assertExpressionSerialize(t, table1ColTime.GT_EQ(Time(10, 20, 0, 0)), "(table1.colTime >= $1)", "10:20:00.000") + assertExpressionSerialize(t, table1ColTime.GT_EQ(Time(10, 20, 0, 0)), "(table1.colTime >= $1::time without time zone)", "10:20:00.000") } diff --git a/tests/select_test.go b/tests/select_test.go index a4387a6..3ed9ace 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -973,7 +973,7 @@ SELECT payment.payment_id AS "payment.payment_id", payment.amount AS "payment.amount", payment.payment_date AS "payment.payment_date" FROM dvds.payment -WHERE payment.payment_date < '2007-02-14 22:16:01.000' +WHERE payment.payment_date < '2007-02-14 22:16:01.000'::timestamp without time zone ORDER BY payment.payment_date ASC; ` diff --git a/tests/types_test.go b/tests/types_test.go index ffe1d41..c12ac6b 100644 --- a/tests/types_test.go +++ b/tests/types_test.go @@ -111,7 +111,10 @@ func TestStringOperators(t *testing.T) { TO_HEX(AllTypes.IntegerPtr), ) - //fmt.Println(query.Sql()) + _, args, _ := query.Sql() + + fmt.Println(query.Sql()) + fmt.Println(args[15]) fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -284,6 +287,17 @@ func TestTimeOperators(t *testing.T) { AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)), + + CURRENT_DATE(), + CURRENT_TIME(), + CURRENT_TIME(2), + CURRENT_TIMESTAMP(), + CURRENT_TIMESTAMP(1), + LOCALTIME(), + LOCALTIME(11), + LOCALTIMESTAMP(), + LOCALTIMESTAMP(4), + NOW(), ) fmt.Println(query.DebugSql())