1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2025-07-03 04:38:38 -04:00

Pass DB as a closure instead of a global variable

This is part of an ongoing attempt to restructure the code to make it
more 'testable'.
This commit is contained in:
2017-10-04 08:05:51 -04:00
parent 9abafa50b2
commit 156b9aaf0c
13 changed files with 253 additions and 208 deletions

View File

@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
if t.Splits[i].AccountId != -1 {
var err error
var account *Account
if transaction != nil {
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
} else {
account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
}
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
if err != nil {
return nil, err
}
@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
return sums, nil
}
func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
return t.GetImbalancesTx(nil)
}
// Returns true if all securities contained in this transaction are balanced,
// false otherwise
func (t *Transaction) Balanced() (bool, error) {
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
var zero big.Rat
sums, err := t.GetImbalances()
sums, err := t.GetImbalancesTx(transaction)
if err != nil {
return false, err
}
@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) {
return true, nil
}
func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
var t Transaction
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
if err != nil {
transaction.Rollback()
return nil, err
}
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
if err != nil {
transaction.Rollback()
return nil, err
}
@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
return &t, nil
}
func GetTransactions(userid int64) (*[]Transaction, error) {
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
var transactions []Transaction
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
return nil
}
func InsertTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
func InsertTransaction(db *DB, t *Transaction, user *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error {
return nil
}
func UpdateTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
if err != nil {
return err
}
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
var existing_splits []*Split
_, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
if err != nil {
transaction.Rollback()
return err
}
@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error {
if ok {
count, err := transaction.Update(t.Splits[i])
if err != nil {
transaction.Rollback()
return err
}
if count != 1 {
transaction.Rollback()
return errors.New("Updated more than one transaction split")
}
delete(s_map, t.Splits[i].SplitId)
@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
t.Splits[i].SplitId = -1
err := transaction.Insert(t.Splits[i])
if err != nil {
transaction.Rollback()
return err
}
}
@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
if ok {
_, err := transaction.Delete(existing_splits[i])
if err != nil {
transaction.Rollback()
return err
}
}
@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error {
}
err = incrementAccountVersions(transaction, user, a_ids)
if err != nil {
transaction.Rollback()
return err
}
count, err := transaction.Update(t)
if err != nil {
transaction.Rollback()
return err
}
if count != 1 {
transaction.Rollback()
return errors.New("Updated more than one transaction")
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return err
}
return nil
}
func DeleteTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
transaction, err := db.Begin()
if err != nil {
return err
}
@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error {
return nil
}
func TransactionHandler(w http.ResponseWriter, r *http.Request) {
user, err := GetUserFromSession(r)
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
user, err := GetUserFromSession(db, r)
if err != nil {
WriteError(w, 1 /*Not Signed In*/)
return
@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
transaction.TransactionId = -1
transaction.UserId = user.UserId
balanced, err := transaction.Balanced()
sqltx, err := db.Begin()
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
balanced, err := transaction.Balanced(sqltx)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
if !transaction.Valid() || !balanced {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
for i := range transaction.Splits {
transaction.Splits[i].SplitId = -1
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
if err != nil {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
}
err = InsertTransaction(&transaction, user)
err = InsertTransactionTx(sqltx, &transaction, user)
if err != nil {
if _, ok := err.(AccountMissingError); ok {
WriteError(w, 3 /*Invalid Request*/)
@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
}
sqltx.Rollback()
return
}
err = sqltx.Commit()
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
//Return all Transactions
var al TransactionList
transactions, err := GetTransactions(user.UserId)
transactions, err := GetTransactions(db, user.UserId)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
} else {
//Return Transaction with this Id
transaction, err := GetTransaction(transactionid, user.UserId)
transaction, err := GetTransaction(db, transactionid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
transaction.UserId = user.UserId
balanced, err := transaction.Balanced()
sqltx, err := db.Begin()
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
balanced, err := transaction.Balanced(sqltx)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
if !transaction.Valid() || !balanced {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
for i := range transaction.Splits {
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
if err != nil {
sqltx.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
}
err = UpdateTransaction(&transaction, user)
err = UpdateTransactionTx(sqltx, &transaction, user)
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
err = sqltx.Commit()
if err != nil {
sqltx.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
return
}
transaction, err := GetTransaction(transactionid, user.UserId)
transaction, err := GetTransaction(db, transactionid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
err = DeleteTransaction(transaction, user)
err = DeleteTransaction(db, transaction, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
return &pageDifference, nil
}
func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
var splits []Split
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
}
// Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) {
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []Split
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R
return &balance, nil
}
func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []Split
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti
return &balance, nil
}
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
var transactions []Transaction
var atl AccountTransactionsList
transaction, err := DB.Begin()
transaction, err := db.Begin()
if err != nil {
return nil, err
}
@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6
// Return only those transactions which have at least one split pertaining to
// an account
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
user *User, accountid int64) {
var page uint64 = 0
@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
sort = sortstring
}
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit)
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)