diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 2022a2a..b330ed6 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -14,8 +14,6 @@ type ResponseWriterWriter interface { Write(http.ResponseWriter) error } -type Tx = gorp.Transaction - type Context struct { Tx *Tx User *User @@ -51,7 +49,7 @@ type APIHandler struct { } func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { - tx, err := ah.DB.Begin() + tx, err := GetTx(ah.DB) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/tx.go b/internal/handlers/tx.go new file mode 100644 index 0000000..ae19a4b --- /dev/null +++ b/internal/handlers/tx.go @@ -0,0 +1,65 @@ +package handlers + +import ( + "database/sql" + "gopkg.in/gorp.v1" + "strings" +) + +type Tx struct { + Dialect gorp.Dialect + Tx *gorp.Transaction +} + +func (tx *Tx) Rebind(query string) string { + chunks := strings.Split(query, "?") + str := chunks[0] + for i := 1; i < len(chunks); i++ { + str += tx.Dialect.BindVar(i-1) + chunks[i] + } + return str +} + +func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return tx.Tx.Select(i, tx.Rebind(query), args...) +} + +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.Tx.Exec(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { + return tx.Tx.SelectInt(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { + return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) +} + +func (tx *Tx) Insert(list ...interface{}) error { + return tx.Tx.Insert(list...) +} + +func (tx *Tx) Update(list ...interface{}) (int64, error) { + return tx.Tx.Update(list...) +} + +func (tx *Tx) Delete(list ...interface{}) (int64, error) { + return tx.Tx.Delete(list...) +} + +func (tx *Tx) Commit() error { + return tx.Tx.Commit() +} + +func (tx *Tx) Rollback() error { + return tx.Tx.Rollback() +} + +func GetTx(db *gorp.DbMap) (*Tx, error) { + tx, err := db.Begin() + if err != nil { + return nil, err + } + return &Tx{db.Dialect, tx}, nil +}