mirror of
https://github.com/aclindsa/moneygo.git
synced 2025-07-03 04:38:38 -04:00
Pass DB as a closure instead of a global variable
This is part of an ongoing attempt to restructure the code to make it more 'testable'.
This commit is contained in:
137
transactions.go
137
transactions.go
@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
var err error
|
||||
var account *Account
|
||||
if transaction != nil {
|
||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
||||
} else {
|
||||
account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
|
||||
}
|
||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
||||
return sums, nil
|
||||
}
|
||||
|
||||
func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
|
||||
return t.GetImbalancesTx(nil)
|
||||
}
|
||||
|
||||
// Returns true if all securities contained in this transaction are balanced,
|
||||
// false otherwise
|
||||
func (t *Transaction) Balanced() (bool, error) {
|
||||
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
|
||||
var zero big.Rat
|
||||
|
||||
sums, err := t.GetImbalances()
|
||||
sums, err := t.GetImbalancesTx(transaction)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
|
||||
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
|
||||
var t Transaction
|
||||
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func GetTransactions(userid int64) (*[]Transaction, error) {
|
||||
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
|
||||
var transactions []Transaction
|
||||
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertTransaction(t *Transaction, user *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
func InsertTransaction(db *DB, t *Transaction, user *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateTransaction(t *Transaction, user *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
||||
var existing_splits []*Split
|
||||
|
||||
_, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
if ok {
|
||||
count, err := transaction.Update(t.Splits[i])
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Updated more than one transaction split")
|
||||
}
|
||||
delete(s_map, t.Splits[i].SplitId)
|
||||
@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
t.Splits[i].SplitId = -1
|
||||
err := transaction.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
if ok {
|
||||
_, err := transaction.Delete(existing_splits[i])
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
}
|
||||
err = incrementAccountVersions(transaction, user, a_ids)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.Update(t)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Updated more than one transaction")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteTransaction(t *Transaction, user *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
transaction.TransactionId = -1
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
balanced, err := transaction.Balanced()
|
||||
sqltx, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
balanced, err := transaction.Balanced(sqltx)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if !transaction.Valid() || !balanced {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
transaction.Splits[i].SplitId = -1
|
||||
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = InsertTransaction(&transaction, user)
|
||||
err = InsertTransactionTx(sqltx, &transaction, user)
|
||||
if err != nil {
|
||||
if _, ok := err.(AccountMissingError); ok {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
sqltx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
err = sqltx.Commit()
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
//Return all Transactions
|
||||
var al TransactionList
|
||||
transactions, err := GetTransactions(user.UserId)
|
||||
transactions, err := GetTransactions(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
//Return Transaction with this Id
|
||||
transaction, err := GetTransaction(transactionid, user.UserId)
|
||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
balanced, err := transaction.Balanced()
|
||||
sqltx, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
balanced, err := transaction.Balanced(sqltx)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if !transaction.Valid() || !balanced {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = UpdateTransaction(&transaction, user)
|
||||
err = UpdateTransactionTx(sqltx, &transaction, user)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sqltx.Commit()
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
transaction, err := GetTransaction(transactionid, user.UserId)
|
||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteTransaction(transaction, user)
|
||||
err = DeleteTransaction(db, transaction, user)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
|
||||
return &pageDifference, nil
|
||||
}
|
||||
|
||||
func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
|
||||
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
|
||||
}
|
||||
|
||||
// Assumes accountid is valid and is owned by the current user
|
||||
func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||
var transactions []Transaction
|
||||
var atl AccountTransactionsList
|
||||
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6
|
||||
|
||||
// Return only those transactions which have at least one split pertaining to
|
||||
// an account
|
||||
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
|
||||
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
||||
user *User, accountid int64) {
|
||||
|
||||
var page uint64 = 0
|
||||
@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
|
||||
sort = sortstring
|
||||
}
|
||||
|
||||
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit)
|
||||
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
Reference in New Issue
Block a user