Refactor write update (#2310)
All checks were successful
test mariadb / test mariadb (push) Successful in 5m2s
test cockroach / test cockroach (push) Successful in 7m30s
test mssql / test mssql (push) Successful in 6m10s
test mysql / test mysql (push) Successful in 4m55s
test mysql8 / test mysql8 (push) Successful in 5m16s
test postgres / test postgres (push) Successful in 5m55s
test tidb / test tidb (push) Successful in 5m30s
test sqlite / unit test & test sqlite (push) Successful in 8m37s
All checks were successful
test mariadb / test mariadb (push) Successful in 5m2s
test cockroach / test cockroach (push) Successful in 7m30s
test mssql / test mssql (push) Successful in 6m10s
test mysql / test mysql (push) Successful in 4m55s
test mysql8 / test mysql8 (push) Successful in 5m16s
test postgres / test postgres (push) Successful in 5m55s
test tidb / test tidb (push) Successful in 5m30s
test sqlite / unit test & test sqlite (push) Successful in 8m37s
Reviewed-on: #2310
This commit is contained in:
parent
9aab1f689c
commit
cb4f310151
@ -427,7 +427,158 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error {
|
||||
func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
|
||||
switch t := m.(type) {
|
||||
case map[string]interface{}:
|
||||
conds := []builder.Cond{}
|
||||
for k, v := range t {
|
||||
conds = append(conds, builder.Eq{k: v})
|
||||
}
|
||||
return conds, nil
|
||||
case map[string]string:
|
||||
conds := []builder.Cond{}
|
||||
for k, v := range t {
|
||||
conds = append(conds, builder.Eq{k: v})
|
||||
}
|
||||
return conds, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported condition map type %v", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
|
||||
if v.Type().Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
table := statement.RefTable
|
||||
if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
|
||||
return nil
|
||||
}
|
||||
|
||||
verValue, err := table.VersionColumn().ValueOfV(&v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if verValue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||
for i, expr := range statement.IncrColumns {
|
||||
if i > 0 || hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.Arg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||
// for update action to like "column = column - ?"
|
||||
for i, expr := range statement.DecrColumns {
|
||||
if i > 0 || hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.Arg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
|
||||
// for update action to like "column = expression"
|
||||
for i, expr := range statement.ExprColumns {
|
||||
if i > 0 || hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
switch tp := expr.Arg.(type) {
|
||||
case string:
|
||||
if len(tp) == 0 {
|
||||
tp = "''"
|
||||
}
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
|
||||
return err
|
||||
}
|
||||
case *builder.Builder:
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.Arg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
previousLen := w.Len()
|
||||
for i, colName := range colNames {
|
||||
if i > 0 {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(w, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.Append(args...)
|
||||
|
||||
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrNoColumnsTobeUpdated is returned by WriteUpdate when writing the SET
// clause produced no output, i.e. there is nothing to update.
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -444,17 +595,16 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
|
||||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||
return err
|
||||
}
|
||||
for i, colName := range colNames {
|
||||
if i > 0 {
|
||||
if _, err := fmt.Fprint(updateWriter, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(updateWriter, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
previousLen := updateWriter.Len()
|
||||
|
||||
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if no columns to be updated, return error
|
||||
if previousLen == updateWriter.Len() {
|
||||
return ErrNoColumnsTobeUpdated
|
||||
}
|
||||
updateWriter.Append(args...)
|
||||
|
||||
// write from
|
||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||
|
@ -5,17 +5,17 @@
|
||||
package xorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/internal/statements"
|
||||
"xorm.io/xorm/internal/utils"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// enumerates all errors
|
||||
var (
|
||||
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||
ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
|
||||
)
|
||||
|
||||
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
||||
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||
v := utils.ReflectValue(bean)
|
||||
t := v.Type()
|
||||
|
||||
var colNames []string
|
||||
var args []interface{}
|
||||
|
||||
// handle before update processors
|
||||
for _, closure := range session.beforeClosures {
|
||||
closure(bean)
|
||||
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||
}
|
||||
// --
|
||||
|
||||
var colNames []string
|
||||
var args []interface{}
|
||||
var err error
|
||||
isMap := t.Kind() == reflect.Map
|
||||
isStruct := t.Kind() == reflect.Struct
|
||||
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||
}
|
||||
}
|
||||
|
||||
// for update action to like "column = column + ?"
|
||||
incColumns := session.statement.IncrColumns
|
||||
for _, expr := range incColumns {
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
|
||||
args = append(args, expr.Arg)
|
||||
}
|
||||
// for update action to like "column = column - ?"
|
||||
decColumns := session.statement.DecrColumns
|
||||
for _, expr := range decColumns {
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
|
||||
args = append(args, expr.Arg)
|
||||
}
|
||||
// for update action to like "column = expression"
|
||||
exprColumns := session.statement.ExprColumns
|
||||
for _, expr := range exprColumns {
|
||||
switch tp := expr.Arg.(type) {
|
||||
case string:
|
||||
if len(tp) == 0 {
|
||||
tp = "''"
|
||||
}
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
|
||||
case *builder.Builder:
|
||||
subQuery, subArgs, err := builder.ToSQL(tp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
subQuery = session.statement.ReplaceQuote(subQuery)
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
|
||||
args = append(args, subArgs...)
|
||||
default:
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
|
||||
args = append(args, expr.Arg)
|
||||
}
|
||||
}
|
||||
|
||||
if err = session.statement.ProcessIDParam(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||
verValue *reflect.Value
|
||||
)
|
||||
if doIncVer {
|
||||
verValue, err = table.VersionColumn().ValueOf(bean)
|
||||
verValue, err = table.VersionColumn().ValueOfV(&v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if verValue != nil {
|
||||
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
||||
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
|
||||
}
|
||||
}
|
||||
|
||||
if len(colNames) == 0 {
|
||||
return 0, ErrNoColumnsTobeUpdated
|
||||
}
|
||||
|
||||
updateWriter := builder.NewWriter()
|
||||
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil {
|
||||
if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user