1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2025-07-03 04:38:38 -04:00

Use SQL transactions for the entirety of every request

This commit is contained in:
2017-10-14 14:20:50 -04:00
parent 6726d9cb2f
commit 4e53a5e59c
14 changed files with 496 additions and 989 deletions

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"gopkg.in/gorp.v1"
"io"
"log"
"net/http"
@ -47,61 +46,52 @@ func (u *User) HashPassword() {
u.Password = ""
}
func GetUser(db *DB, userid int64) (*User, error) {
func GetUser(tx *Tx, userid int64) (*User, error) {
var u User
err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil {
return nil, err
}
return &u, nil
}
func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
func GetUserTx(tx *Tx, userid int64) (*User, error) {
var u User
err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid)
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil {
return nil, err
}
return &u, nil
}
func GetUserByUsername(db *DB, username string) (*User, error) {
func GetUserByUsername(tx *Tx, username string) (*User, error) {
var u User
err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
if err != nil {
return nil, err
}
return &u, nil
}
func InsertUser(db *DB, u *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
func InsertUser(tx *Tx, u *User) error {
security_template := FindCurrencyTemplate(u.DefaultCurrency)
if security_template == nil {
transaction.Rollback()
return errors.New("Invalid ISO4217 Default Currency")
}
existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username)
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username)
if err != nil {
transaction.Rollback()
return err
}
if existing > 0 {
transaction.Rollback()
return UserExistsError{}
}
err = transaction.Insert(u)
err = tx.Insert(u)
if err != nil {
transaction.Rollback()
return err
}
@ -110,201 +100,146 @@ func InsertUser(db *DB, u *User) error {
security = *security_template
security.UserId = u.UserId
err = InsertSecurityTx(transaction, &security)
err = InsertSecurityTx(tx, &security)
if err != nil {
transaction.Rollback()
return err
}
// Update the user's DefaultCurrency to our new SecurityId
u.DefaultCurrency = security.SecurityId
count, err := transaction.Update(u)
count, err := tx.Update(u)
if err != nil {
transaction.Rollback()
return err
} else if count != 1 {
transaction.Rollback()
return errors.New("Would have updated more than one user")
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return err
}
return nil
}
func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
s, err := GetSession(db, r)
func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) {
s, err := GetSession(tx, r)
if err != nil {
return nil, err
}
return GetUser(db, s.UserId)
return GetUser(tx, s.UserId)
}
func UpdateUser(db *DB, u *User) error {
transaction, err := db.Begin()
func GetUserFromSessionTx(tx *Tx, r *http.Request) (*User, error) {
s, err := GetSessionTx(tx, r)
if err != nil {
return err
return nil, err
}
return GetUserTx(tx, s.UserId)
}
security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId)
func UpdateUser(tx *Tx, u *User) error {
security, err := GetSecurityTx(tx, u.DefaultCurrency, u.UserId)
if err != nil {
transaction.Rollback()
return err
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
transaction.Rollback()
return errors.New("UserId and DefaultCurrency don't match the fetched security")
} else if security.Type != Currency {
transaction.Rollback()
return errors.New("New DefaultCurrency security is not a currency")
}
count, err := transaction.Update(u)
count, err := tx.Update(u)
if err != nil {
transaction.Rollback()
return err
} else if count != 1 {
transaction.Rollback()
return errors.New("Would have updated more than one user")
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return err
}
return nil
}
func DeleteUser(db *DB, u *User) error {
transaction, err := db.Begin()
func DeleteUser(tx *Tx, u *User) error {
count, err := tx.Delete(u)
if err != nil {
return err
}
count, err := transaction.Delete(u)
if err != nil {
transaction.Rollback()
return err
}
if count != 1 {
transaction.Rollback()
return fmt.Errorf("No user to delete")
}
_, err = transaction.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId)
_, err = tx.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId)
_, err = tx.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
if err != nil {
transaction.Rollback()
return err
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return err
}
return nil
}
func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
if r.Method == "POST" {
user_json := r.PostFormValue("user")
if user_json == "" {
WriteError(w, 3 /*Invalid Request*/)
return
return NewError(3 /*Invalid Request*/)
}
var user User
err := user.Read(user_json)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
return NewError(3 /*Invalid Request*/)
}
user.UserId = -1
user.HashPassword()
err = InsertUser(db, &user)
err = InsertUser(tx, &user)
if err != nil {
if _, ok := err.(UserExistsError); ok {
WriteError(w, 4 /*User Exists*/)
return NewError(4 /*User Exists*/)
} else {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return NewError(999 /*Internal Error*/)
}
return
}
w.WriteHeader(201 /*Created*/)
err = user.Write(w)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
return ResponseWrapper{201, &user}
} else {
user, err := GetUserFromSession(db, r)
user, err := GetUserFromSession(tx, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
return NewError(1 /*Not Signed In*/)
}
userid, err := GetURLID(r.URL.Path)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
return NewError(3 /*Invalid Request*/)
}
if userid != user.UserId {
WriteError(w, 2 /*Unauthorized Access*/)
return
return NewError(2 /*Unauthorized Access*/)
}
if r.Method == "GET" {
err = user.Write(w)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
return user
} else if r.Method == "PUT" {
user_json := r.PostFormValue("user")
if user_json == "" {
WriteError(w, 3 /*Invalid Request*/)
return
return NewError(3 /*Invalid Request*/)
}
// Save old PWHash in case the new password is bogus
@ -312,8 +247,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
err = user.Read(user_json)
if err != nil || user.UserId != userid {
WriteError(w, 3 /*Invalid Request*/)
return
return NewError(3 /*Invalid Request*/)
}
// If the user didn't create a new password, keep their old one
@ -324,27 +258,21 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user.PasswordHash = old_pwhash
}
err = UpdateUser(db, user)
err = UpdateUser(tx, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
return NewError(999 /*Internal Error*/)
}
err = user.Write(w)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
return user
} else if r.Method == "DELETE" {
err := DeleteUser(db, user)
err := DeleteUser(tx, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
return NewError(999 /*Internal Error*/)
}
WriteSuccess(w)
return SuccessWriter{}
}
}
return NewError(3 /*Invalid Request*/)
}