mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-10-30 01:23:26 -04:00 
			
		
		
		
	Use SQL transactions for the entirety of every request
This commit is contained in:
		| @@ -129,10 +129,10 @@ func (al *AccountList) Read(json_str string) error { | ||||
| 	return dec.Decode(al) | ||||
| } | ||||
|  | ||||
| func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) { | ||||
| func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) { | ||||
| 	var a Account | ||||
|  | ||||
| 	err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) | ||||
| 	err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -150,10 +150,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) | ||||
| 	return &a, nil | ||||
| } | ||||
|  | ||||
| func GetAccounts(db *DB, userid int64) (*[]Account, error) { | ||||
| func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { | ||||
| 	var accounts []Account | ||||
|  | ||||
| 	_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid) | ||||
| 	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -293,12 +293,7 @@ func (cae CircularAccountsError) Error() string { | ||||
| 	return "Would result in circular account relationship" | ||||
| } | ||||
|  | ||||
| func insertUpdateAccount(db *DB, a *Account, insert bool) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { | ||||
| 	found := make(map[int64]bool) | ||||
| 	if !insert { | ||||
| 		found[a.AccountId] = true | ||||
| @@ -308,14 +303,12 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error { | ||||
| 	for parentid != -1 { | ||||
| 		depth += 1 | ||||
| 		if depth > 100 { | ||||
| 			transaction.Rollback() | ||||
| 			return TooMuchNestingError{} | ||||
| 		} | ||||
|  | ||||
| 		var a Account | ||||
| 		err := transaction.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) | ||||
| 		err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return ParentAccountMissingError{} | ||||
| 		} | ||||
|  | ||||
| @@ -327,107 +320,79 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error { | ||||
| 		found[parentid] = true | ||||
| 		parentid = a.ParentAccountId | ||||
| 		if _, ok := found[parentid]; ok { | ||||
| 			transaction.Rollback() | ||||
| 			return CircularAccountsError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if insert { | ||||
| 		err = transaction.Insert(a) | ||||
| 		err := tx.Insert(a) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId) | ||||
| 		oldacct, err := GetAccountTx(tx, a.AccountId, a.UserId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		a.AccountVersion = oldacct.AccountVersion + 1 | ||||
|  | ||||
| 		count, err := transaction.Update(a) | ||||
| 		count, err := tx.Update(a) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			transaction.Rollback() | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InsertAccount(db *DB, a *Account) error { | ||||
| 	return insertUpdateAccount(db, a, true) | ||||
| func InsertAccount(tx *Tx, a *Account) error { | ||||
| 	return insertUpdateAccount(tx, a, true) | ||||
| } | ||||
|  | ||||
| func UpdateAccount(db *DB, a *Account) error { | ||||
| 	return insertUpdateAccount(db, a, false) | ||||
| } | ||||
|  | ||||
| func DeleteAccount(db *DB, a *Account) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| func UpdateAccount(tx *Tx, a *Account) error { | ||||
| 	return insertUpdateAccount(tx, a, false) | ||||
| } | ||||
|  | ||||
| func DeleteAccount(tx *Tx, a *Account) error { | ||||
| 	if a.ParentAccountId != -1 { | ||||
| 		// Re-parent splits to this account's parent account if this account isn't a root account | ||||
| 		_, err = transaction.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 		_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Delete splits if this account is a root account | ||||
| 		_, err = transaction.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) | ||||
| 		_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Re-parent child accounts to this account's parent account | ||||
| 	_, err = transaction.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 	_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(a) | ||||
| 	count, err := tx.Delete(a) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Was going to delete more than one account") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSession(tx, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 		return NewError(1 /*Not Signed In*/) | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| @@ -439,59 +404,46 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 			n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype) | ||||
|  | ||||
| 			if err != nil || n != 2 { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			AccountImportHandler(db, w, r, user, accountid, importtype) | ||||
| 			return | ||||
| 			return AccountImportHandler(tx, r, user, accountid, importtype) | ||||
| 		} | ||||
|  | ||||
| 		account_json := r.PostFormValue("account") | ||||
| 		if account_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		var account Account | ||||
| 		err := account.Read(account_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		account.AccountId = -1 | ||||
| 		account.UserId = user.UserId | ||||
| 		account.AccountVersion = 0 | ||||
|  | ||||
| 		security, err := GetSecurity(db, account.SecurityId, user.UserId) | ||||
| 		security, err := GetSecurity(tx, account.SecurityId, user.UserId) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		if security == nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		err = InsertAccount(db, &account) | ||||
| 		err = InsertAccount(tx, &account) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(ParentAccountMissingError); ok { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} else { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = account.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		return ResponseWrapper{201, &account} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		var accountid int64 | ||||
| 		n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) | ||||
| @@ -499,112 +451,86 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 		if err != nil || n != 1 { | ||||
| 			//Return all Accounts | ||||
| 			var al AccountList | ||||
| 			accounts, err := GetAccounts(db, user.UserId) | ||||
| 			accounts, err := GetAccounts(tx, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			al.Accounts = accounts | ||||
| 			err = (&al).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &al | ||||
| 		} else { | ||||
| 			// if URL looks like /account/[0-9]+/transactions, use the account | ||||
| 			// transaction handler | ||||
| 			if accountTransactionsRE.MatchString(r.URL.Path) { | ||||
| 				AccountTransactionsHandler(db, w, r, user, accountid) | ||||
| 				return | ||||
| 				return AccountTransactionsHandler(tx, r, user, accountid) | ||||
| 			} | ||||
|  | ||||
| 			// Return Account with this Id | ||||
| 			account, err := GetAccount(db, accountid, user.UserId) | ||||
| 			account, err := GetAccount(tx, accountid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = account.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return account | ||||
| 		} | ||||
| 	} else { | ||||
| 		accountid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		if r.Method == "PUT" { | ||||
| 			account_json := r.PostFormValue("account") | ||||
| 			if account_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			var account Account | ||||
| 			err := account.Read(account_json) | ||||
| 			if err != nil || account.AccountId != accountid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			account.UserId = user.UserId | ||||
|  | ||||
| 			security, err := GetSecurity(db, account.SecurityId, user.UserId) | ||||
| 			security, err := GetSecurity(tx, account.SecurityId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			if security == nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			if account.ParentAccountId == account.AccountId { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateAccount(db, &account) | ||||
| 			err = UpdateAccount(tx, &account) | ||||
| 			if err != nil { | ||||
| 				if _, ok := err.(ParentAccountMissingError); ok { | ||||
| 					WriteError(w, 3 /*Invalid Request*/) | ||||
| 					return NewError(3 /*Invalid Request*/) | ||||
| 				} else if _, ok := err.(CircularAccountsError); ok { | ||||
| 					WriteError(w, 3 /*Invalid Request*/) | ||||
| 					return NewError(3 /*Invalid Request*/) | ||||
| 				} else { | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = account.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &account | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			account, err := GetAccount(db, accountid, user.UserId) | ||||
| 			account, err := GetAccount(tx, accountid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteAccount(db, account) | ||||
| 			err = DeleteAccount(tx, account) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 			return SuccessWriter{} | ||||
| 		} | ||||
| 	} | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|   | ||||
| @@ -15,9 +15,9 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find DB in lua's Context") | ||||
| 		return nil, errors.New("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) | ||||
| @@ -27,7 +27,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		accounts, err := GetAccounts(db, user.UserId) | ||||
| 		accounts, err := GetAccounts(tx, user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -149,9 +149,9 @@ func luaAccountBalance(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find DB in lua's Context") | ||||
| 		panic("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
| 	user, ok := ctx.Value(userContextKey).(*User) | ||||
| 	if !ok { | ||||
| @@ -171,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int { | ||||
| 	if date != nil { | ||||
| 		end := luaWeakCheckTime(L, 3) | ||||
| 		if end != nil { | ||||
| 			rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end) | ||||
| 			rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) | ||||
| 		} else { | ||||
| 			rat, err = GetAccountBalanceDate(db, user, a.AccountId, date) | ||||
| 			rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) | ||||
| 		} | ||||
| 	} else { | ||||
| 		rat, err = GetAccountBalance(db, user, a.AccountId) | ||||
| 		rat, err = GetAccountBalance(tx, user, a.AccountId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		panic("Failed to GetAccountBalance:" + err.Error()) | ||||
|   | ||||
| @@ -38,13 +38,17 @@ var error_codes = map[int]string{ | ||||
| 	999: "Internal Error", | ||||
| } | ||||
|  | ||||
| func WriteError(w http.ResponseWriter, error_code int) { | ||||
| func NewError(error_code int) *Error { | ||||
| 	msg, ok := error_codes[error_code] | ||||
| 	if !ok { | ||||
| 		log.Printf("Error: WriteError received error code of %d", error_code) | ||||
| 		msg = error_codes[999] | ||||
| 	} | ||||
| 	e := Error{error_code, msg} | ||||
| 	return &Error{error_code, msg} | ||||
| } | ||||
|  | ||||
| func WriteError(w http.ResponseWriter, error_code int) { | ||||
| 	e := NewError(error_code) | ||||
|  | ||||
| 	err := e.Write(w) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -308,42 +308,37 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { | ||||
| 	return &gncimport, nil | ||||
| } | ||||
|  | ||||
| func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSession(tx, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 		return NewError(1 /*Not Signed In*/) | ||||
| 	} | ||||
|  | ||||
| 	if r.Method != "POST" { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	multipartReader, err := r.MultipartReader() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	// Assume there is only one 'part' and it's the one we care about | ||||
| 	part, err := multipartReader.NextPart() | ||||
| 	if err != nil { | ||||
| 		if err == io.EOF { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} else { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	bufread := bufio.NewReader(part) | ||||
| 	gzHeader, err := bufread.Peek(2) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(999 /*Internal Error*/) | ||||
| 	} | ||||
|  | ||||
| 	// Does this look like a gzipped file? | ||||
| @@ -351,9 +346,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b { | ||||
| 		gzr, err := gzip.NewReader(bufread) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		gnucashImport, err = ImportGnucash(gzr) | ||||
| 	} else { | ||||
| @@ -361,15 +355,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sqltransaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	// Import securities, building map from Gnucash security IDs to our | ||||
| @@ -377,13 +363,11 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	securityMap := make(map[int64]int64) | ||||
| 	for _, security := range gnucashImport.Securities { | ||||
| 		securityId := security.SecurityId // save off because it could be updated | ||||
| 		s, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &security) | ||||
| 		s, err := ImportGetCreateSecurity(tx, user.UserId, &security) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 6 /*Import Error*/) | ||||
| 			log.Print(err) | ||||
| 			log.Print(security) | ||||
| 			return | ||||
| 			return NewError(6 /*Import Error*/) | ||||
| 		} | ||||
| 		securityMap[securityId] = s.SecurityId | ||||
| 	} | ||||
| @@ -394,12 +378,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 		price.CurrencyId = securityMap[price.CurrencyId] | ||||
| 		price.PriceId = 0 | ||||
|  | ||||
| 		err := CreatePriceIfNotExist(sqltransaction, &price) | ||||
| 		err := CreatePriceIfNotExist(tx, &price) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 6 /*Import Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(6 /*Import Error*/) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -425,12 +407,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 					account.ParentAccountId = accountMap[account.ParentAccountId] | ||||
| 				} | ||||
| 				account.SecurityId = securityMap[account.SecurityId] | ||||
| 				a, err := GetCreateAccountTx(sqltransaction, account) | ||||
| 				a, err := GetCreateAccountTx(tx, account) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
| 				accountMap[account.AccountId] = a.AccountId | ||||
| 				accountsRemaining-- | ||||
| @@ -438,10 +418,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 		} | ||||
| 		if accountsRemaining == accountsRemainingLast { | ||||
| 			//We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName())) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		accountsRemainingLast = accountsRemaining | ||||
| 	} | ||||
| @@ -453,41 +431,27 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 		for _, split := range transaction.Splits { | ||||
| 			acctId, ok := accountMap[split.AccountId] | ||||
| 			if !ok { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId)) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			split.AccountId = acctId | ||||
|  | ||||
| 			exists, err := split.AlreadyImportedTx(sqltransaction) | ||||
| 			exists, err := split.AlreadyImportedTx(tx) | ||||
| 			if err != nil { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print("Error checking if split was already imported:", err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} else if exists { | ||||
| 				already_imported = true | ||||
| 			} | ||||
| 		} | ||||
| 		if !already_imported { | ||||
| 			err := InsertTransactionTx(sqltransaction, &transaction, user) | ||||
| 			err := InsertTransactionTx(tx, &transaction, user) | ||||
| 			if err != nil { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = sqltransaction.Commit() | ||||
| 	if err != nil { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	WriteSuccess(w) | ||||
| 	return SuccessWriter{} | ||||
| } | ||||
|   | ||||
| @@ -2,30 +2,64 @@ package handlers | ||||
|  | ||||
| import ( | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| // Create a closure over db, allowing the handlers to look like a | ||||
| // http.HandlerFunc | ||||
| type DB = gorp.DbMap | ||||
| type DBHandler func(http.ResponseWriter, *http.Request, *DB) | ||||
| // But who writes the ResponseWriterWriter? | ||||
| type ResponseWriterWriter interface { | ||||
| 	Write(http.ResponseWriter) error | ||||
| } | ||||
| type Tx = gorp.Transaction | ||||
| type TxHandler func(*http.Request, *Tx) ResponseWriterWriter | ||||
|  | ||||
| func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc { | ||||
| func TxHandlerFunc(t TxHandler, db *gorp.DbMap) http.HandlerFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 		h(w, r, db) | ||||
| 		tx, err := db.Begin() | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			return | ||||
| 		} | ||||
| 		defer func() { | ||||
| 			if r := recover(); r != nil { | ||||
| 				tx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				panic(r) | ||||
| 			} | ||||
| 		}() | ||||
|  | ||||
| 		writer := t(r, tx) | ||||
|  | ||||
| 		if e, ok := writer.(*Error); ok { | ||||
| 			tx.Rollback() | ||||
| 			e.Write(w) | ||||
| 		} else { | ||||
| 			err = tx.Commit() | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 			} else { | ||||
| 				err = writer.Write(w) | ||||
| 				if err != nil { | ||||
| 					log.Print(err) | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetHandler(db *DB) *http.ServeMux { | ||||
| func GetHandler(db *gorp.DbMap) *http.ServeMux { | ||||
| 	servemux := http.NewServeMux() | ||||
| 	servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db)) | ||||
| 	servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db)) | ||||
| 	servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db)) | ||||
| 	servemux.HandleFunc("/session/", TxHandlerFunc(SessionHandler, db)) | ||||
| 	servemux.HandleFunc("/user/", TxHandlerFunc(UserHandler, db)) | ||||
| 	servemux.HandleFunc("/security/", TxHandlerFunc(SecurityHandler, db)) | ||||
| 	servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) | ||||
| 	servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db)) | ||||
| 	servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db)) | ||||
| 	servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db)) | ||||
| 	servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db)) | ||||
| 	servemux.HandleFunc("/account/", TxHandlerFunc(AccountHandler, db)) | ||||
| 	servemux.HandleFunc("/transaction/", TxHandlerFunc(TransactionHandler, db)) | ||||
| 	servemux.HandleFunc("/import/gnucash", TxHandlerFunc(GnucashImportHandler, db)) | ||||
| 	servemux.HandleFunc("/report/", TxHandlerFunc(ReportHandler, db)) | ||||
|  | ||||
| 	return servemux | ||||
| } | ||||
|   | ||||
| @@ -22,48 +22,35 @@ func (od *OFXDownload) Read(json_str string) error { | ||||
| 	return dec.Decode(od) | ||||
| } | ||||
|  | ||||
| func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) { | ||||
| func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter { | ||||
| 	itl, err := ImportOFX(r) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		//TODO is this necessarily an invalid request (what if it was an error on our end)? | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	if len(itl.Accounts) != 1 { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sqltransaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	// Return Account with this Id | ||||
| 	account, err := GetAccountTx(sqltransaction, accountid, user.UserId) | ||||
| 	account, err := GetAccountTx(tx, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	importedAccount := itl.Accounts[0] | ||||
|  | ||||
| 	if len(account.ExternalAccountId) > 0 && | ||||
| 		account.ExternalAccountId != importedAccount.ExternalAccountId { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"", | ||||
| 			importedAccount.ExternalAccountId, | ||||
| 			account.ExternalAccountId) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	// Find matching existing securities or create new ones for those | ||||
| @@ -74,21 +61,17 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 		// save off since ImportGetCreateSecurity overwrites SecurityId on | ||||
| 		// ofxsecurity | ||||
| 		oldsecurityid := ofxsecurity.SecurityId | ||||
| 		security, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &ofxsecurity) | ||||
| 		security, err := ImportGetCreateSecurity(tx, user.UserId, &ofxsecurity) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		securitymap[oldsecurityid] = *security | ||||
| 	} | ||||
|  | ||||
| 	if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	// TODO Ensure all transactions have at least one split in the account | ||||
| @@ -99,10 +82,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 		transaction.UserId = user.UserId | ||||
|  | ||||
| 		if !transaction.Valid() { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print("Unexpected invalid transaction from OFX import") | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
|  | ||||
| 		// Ensure that either AccountId or SecurityId is set for this split, | ||||
| @@ -112,10 +93,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 			split.Status = Imported | ||||
| 			if split.AccountId != -1 { | ||||
| 				if split.AccountId != importedAccount.AccountId { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print("Imported split's AccountId wasn't -1 but also didn't match the account") | ||||
| 					return | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
| 				split.AccountId = account.AccountId | ||||
| 			} else if split.SecurityId != -1 { | ||||
| @@ -123,12 +102,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 					// TODO try to auto-match splits to existing accounts based on past transactions that look like this one | ||||
| 					if split.ImportSplitType == TradingAccount { | ||||
| 						// Find/make trading account if we're that type of split | ||||
| 						trading_account, err := GetTradingAccount(sqltransaction, user.UserId, sec.SecurityId) | ||||
| 						trading_account, err := GetTradingAccount(tx, user.UserId, sec.SecurityId) | ||||
| 						if err != nil { | ||||
| 							sqltransaction.Rollback() | ||||
| 							WriteError(w, 999 /*Internal Error*/) | ||||
| 							log.Print("Couldn't find split's SecurityId in map during OFX import") | ||||
| 							return | ||||
| 							return NewError(999 /*Internal Error*/) | ||||
| 						} | ||||
| 						split.AccountId = trading_account.AccountId | ||||
| 						split.SecurityId = -1 | ||||
| @@ -140,12 +117,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 							SecurityId:      sec.SecurityId, | ||||
| 							Type:            account.Type, | ||||
| 						} | ||||
| 						subaccount, err := GetCreateAccountTx(sqltransaction, *subaccount) | ||||
| 						subaccount, err := GetCreateAccountTx(tx, *subaccount) | ||||
| 						if err != nil { | ||||
| 							sqltransaction.Rollback() | ||||
| 							WriteError(w, 999 /*Internal Error*/) | ||||
| 							log.Print(err) | ||||
| 							return | ||||
| 							return NewError(999 /*Internal Error*/) | ||||
| 						} | ||||
| 						split.AccountId = subaccount.AccountId | ||||
| 						split.SecurityId = -1 | ||||
| @@ -153,49 +128,39 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 						split.SecurityId = sec.SecurityId | ||||
| 					} | ||||
| 				} else { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print("Couldn't find split's SecurityId in map during OFX import") | ||||
| 					return | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
| 			} else { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import") | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		imbalances, err := transaction.GetImbalancesTx(sqltransaction) | ||||
| 		imbalances, err := transaction.GetImbalancesTx(tx) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
|  | ||||
| 		// Fixup any imbalances in transactions | ||||
| 		var zero big.Rat | ||||
| 		for imbalanced_security, imbalance := range imbalances { | ||||
| 			if imbalance.Cmp(&zero) != 0 { | ||||
| 				imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security) | ||||
| 				imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, imbalanced_security) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
|  | ||||
| 				// Add new split to fixup imbalance | ||||
| 				split := new(Split) | ||||
| 				r := new(big.Rat) | ||||
| 				r.Neg(&imbalance) | ||||
| 				security, err := GetSecurityTx(sqltransaction, imbalanced_security, user.UserId) | ||||
| 				security, err := GetSecurityTx(tx, imbalanced_security, user.UserId) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
| 				split.Amount = r.FloatString(security.Precision) | ||||
| 				split.SecurityId = -1 | ||||
| @@ -210,24 +175,20 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 		var already_imported bool | ||||
| 		for _, split := range transaction.Splits { | ||||
| 			if split.SecurityId != -1 || split.AccountId == -1 { | ||||
| 				imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId) | ||||
| 				imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, split.SecurityId) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| 				} | ||||
|  | ||||
| 				split.AccountId = imbalanced_account.AccountId | ||||
| 				split.SecurityId = -1 | ||||
| 			} | ||||
|  | ||||
| 			exists, err := split.AlreadyImportedTx(sqltransaction) | ||||
| 			exists, err := split.AlreadyImportedTx(tx) | ||||
| 			if err != nil { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print("Error checking if split was already imported:", err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} else if exists { | ||||
| 				already_imported = true | ||||
| 			} | ||||
| @@ -239,55 +200,38 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc | ||||
| 	} | ||||
|  | ||||
| 	for _, transaction := range transactions { | ||||
| 		err := InsertTransactionTx(sqltransaction, &transaction, user) | ||||
| 		err := InsertTransactionTx(tx, &transaction, user) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = sqltransaction.Commit() | ||||
| 	if err != nil { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	return SuccessWriter{} | ||||
| } | ||||
|  | ||||
| 	WriteSuccess(w) | ||||
| } | ||||
|  | ||||
| func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { | ||||
| func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { | ||||
| 	download_json := r.PostFormValue("ofxdownload") | ||||
| 	if download_json == "" { | ||||
| 		log.Print("download_json") | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	var ofxdownload OFXDownload | ||||
| 	err := ofxdownload.Read(download_json) | ||||
| 	if err != nil { | ||||
| 		log.Print("ofxdownload.Read") | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	account, err := GetAccount(db, accountid, user.UserId) | ||||
| 	account, err := GetAccount(tx, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		log.Print("GetAccount") | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	ofxver := ofxgo.OfxVersion203 | ||||
| 	if len(account.OFXVersion) != 0 { | ||||
| 		ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion) | ||||
| 		if err != nil { | ||||
| 			log.Print("NewOfxVersion") | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -308,9 +252,8 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User | ||||
|  | ||||
| 	transactionuid, err := ofxgo.RandomUID() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Println("Error creating uid for transaction:", err) | ||||
| 		return | ||||
| 		return NewError(999 /*Internal Error*/) | ||||
| 	} | ||||
|  | ||||
| 	if account.Type == Investment { | ||||
| @@ -343,8 +286,7 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User | ||||
| 		// Import generic bank transactions | ||||
| 		acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		statementRequest := ofxgo.StatementRequest{ | ||||
| 			TrnUID: *transactionuid, | ||||
| @@ -361,49 +303,46 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User | ||||
| 	response, err := client.RequestNoParse(&query) | ||||
| 	if err != nil { | ||||
| 		// TODO this could be an error talking with the OFX server... | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
| 	defer response.Body.Close() | ||||
|  | ||||
| 	ofxImportHelper(db, response.Body, w, user, accountid) | ||||
| 	return ofxImportHelper(tx, response.Body, user, accountid) | ||||
| } | ||||
|  | ||||
| func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { | ||||
| func OFXFileImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { | ||||
| 	multipartReader, err := r.MultipartReader() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	// assume there is only one 'part' | ||||
| 	part, err := multipartReader.NextPart() | ||||
| 	if err != nil { | ||||
| 		if err == io.EOF { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 			log.Print("Encountered unexpected EOF") | ||||
| 		} else { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	ofxImportHelper(db, part, w, user, accountid) | ||||
| 	return ofxImportHelper(tx, part, user, accountid) | ||||
| } | ||||
|  | ||||
| /* | ||||
|  * Assumes the User is a valid, signed-in user, but accountid has not yet been validated | ||||
|  */ | ||||
| func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { | ||||
| func AccountImportHandler(tx *Tx, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter { | ||||
|  | ||||
| 	switch importtype { | ||||
| 	case "ofx": | ||||
| 		OFXImportHandler(db, w, r, user, accountid) | ||||
| 		return OFXImportHandler(tx, r, user, accountid) | ||||
| 	case "ofxfile": | ||||
| 		OFXFileImportHandler(db, w, r, user, accountid) | ||||
| 		return OFXFileImportHandler(tx, r, user, accountid) | ||||
| 	default: | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -91,23 +91,6 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	price, err := GetClosestPriceTx(transaction, security, currency, date) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return price, nil | ||||
| func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { | ||||
| 	return GetClosestPriceTx(tx, security, currency, date) | ||||
| } | ||||
|   | ||||
| @@ -77,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error { | ||||
| 	return enc.Encode(r) | ||||
| } | ||||
|  | ||||
| func GetReport(db *DB, reportid int64, userid int64) (*Report, error) { | ||||
| func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { | ||||
| 	var r Report | ||||
|  | ||||
| 	err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) | ||||
| 	err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &r, nil | ||||
| } | ||||
|  | ||||
| func GetReports(db *DB, userid int64) (*[]Report, error) { | ||||
| func GetReports(tx *Tx, userid int64) (*[]Report, error) { | ||||
| 	var reports []Report | ||||
|  | ||||
| 	_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid) | ||||
| 	_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &reports, nil | ||||
| } | ||||
|  | ||||
| func InsertReport(db *DB, r *Report) error { | ||||
| 	err := db.Insert(r) | ||||
| func InsertReport(tx *Tx, r *Report) error { | ||||
| 	err := tx.Insert(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateReport(db *DB, r *Report) error { | ||||
| 	count, err := db.Update(r) | ||||
| func UpdateReport(tx *Tx, r *Report) error { | ||||
| 	count, err := tx.Update(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -116,8 +116,8 @@ func UpdateReport(db *DB, r *Report) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteReport(db *DB, r *Report) error { | ||||
| 	count, err := db.Delete(r) | ||||
| func DeleteReport(tx *Tx, r *Report) error { | ||||
| 	count, err := tx.Delete(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -127,14 +127,14 @@ func DeleteReport(db *DB, r *Report) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { | ||||
| func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { | ||||
| 	// Create a new LState without opening the default libs for security | ||||
| 	L := lua.NewState(lua.Options{SkipOpenLibs: true}) | ||||
| 	defer L.Close() | ||||
|  | ||||
| 	// Create a new context holding the current user with a timeout | ||||
| 	ctx := context.WithValue(context.Background(), userContextKey, user) | ||||
| 	ctx = context.WithValue(ctx, dbContextKey, db) | ||||
| 	ctx = context.WithValue(ctx, dbContextKey, tx) | ||||
| 	ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) | ||||
| 	defer cancel() | ||||
| 	L.SetContext(ctx) | ||||
| @@ -191,79 +191,60 @@ func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) { | ||||
| 	report, err := GetReport(db, reportid, user.UserId) | ||||
| func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter { | ||||
| 	report, err := GetReport(tx, reportid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	tabulation, err := runReport(db, user, report) | ||||
| 	tabulation, err := runReport(tx, user, report) | ||||
| 	if err != nil { | ||||
| 		// TODO handle different failure cases differently | ||||
| 		log.Print("runReport returned:", err) | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	tabulation.ReportId = reportid | ||||
|  | ||||
| 	err = tabulation.Write(w) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
| 	return tabulation | ||||
| } | ||||
|  | ||||
| func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSession(tx, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 		return NewError(1 /*Not Signed In*/) | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		report_json := r.PostFormValue("report") | ||||
| 		if report_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		var report Report | ||||
| 		err := report.Read(report_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		report.ReportId = -1 | ||||
| 		report.UserId = user.UserId | ||||
|  | ||||
| 		err = InsertReport(db, &report) | ||||
| 		err = InsertReport(tx, &report) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = report.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		return ResponseWrapper{201, &report} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		if reportTabulationRE.MatchString(r.URL.Path) { | ||||
| 			var reportid int64 | ||||
| 			n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid) | ||||
| 			if err != nil || n != 1 { | ||||
| 				WriteError(w, 999 /*InternalError*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*InternalError*/) | ||||
| 			} | ||||
| 			ReportTabulationHandler(db, w, r, user, reportid) | ||||
| 			return | ||||
| 			return ReportTabulationHandler(tx, r, user, reportid) | ||||
| 		} | ||||
|  | ||||
| 		var reportid int64 | ||||
| @@ -271,84 +252,62 @@ func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 		if err != nil || n != 1 { | ||||
| 			//Return all Reports | ||||
| 			var rl ReportList | ||||
| 			reports, err := GetReports(db, user.UserId) | ||||
| 			reports, err := GetReports(tx, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			rl.Reports = reports | ||||
| 			err = (&rl).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &rl | ||||
| 		} else { | ||||
| 			// Return Report with this Id | ||||
| 			report, err := GetReport(db, reportid, user.UserId) | ||||
| 			report, err := GetReport(tx, reportid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = report.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return report | ||||
| 		} | ||||
| 	} else { | ||||
| 		reportid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		if r.Method == "PUT" { | ||||
| 			report_json := r.PostFormValue("report") | ||||
| 			if report_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			var report Report | ||||
| 			err := report.Read(report_json) | ||||
| 			if err != nil || report.ReportId != reportid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			report.UserId = user.UserId | ||||
|  | ||||
| 			err = UpdateReport(db, &report) | ||||
| 			err = UpdateReport(tx, &report) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			err = report.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &report | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			report, err := GetReport(db, reportid, user.UserId) | ||||
| 			report, err := GetReport(tx, reportid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteReport(db, report) | ||||
| 			err = DeleteReport(tx, report) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 			return SuccessWriter{} | ||||
| 		} | ||||
| 	} | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|   | ||||
| @@ -103,10 +103,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) { | ||||
| func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { | ||||
| 	var s Security | ||||
|  | ||||
| 	err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) | ||||
| 	err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -123,18 +123,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64 | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func GetSecurities(db *DB, userid int64) (*[]*Security, error) { | ||||
| func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { | ||||
| 	var securities []*Security | ||||
|  | ||||
| 	_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid) | ||||
| 	_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &securities, nil | ||||
| } | ||||
|  | ||||
| func InsertSecurity(db *DB, s *Security) error { | ||||
| 	err := db.Insert(s) | ||||
| func InsertSecurity(tx *Tx, s *Security) error { | ||||
| 	err := tx.Insert(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -149,37 +149,22 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateSecurity(db *DB, s *Security) error { | ||||
| 	transaction, err := db.Begin() | ||||
| func UpdateSecurity(tx *Tx, s *Security) (err error) { | ||||
| 	user, err := GetUserTx(tx, s.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	user, err := GetUserTx(transaction, s.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 		return | ||||
| 	} else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Cannot change security which is user's default currency to be non-currency") | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Update(s) | ||||
| 	count, err := tx.Update(s) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 		return | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Updated more than one security") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -191,53 +176,36 @@ func (e SecurityInUseError) Error() string { | ||||
| 	return e.message | ||||
| } | ||||
|  | ||||
| func DeleteSecurity(db *DB, s *Security) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| func DeleteSecurity(tx *Tx, s *Security) error { | ||||
| 	// First, ensure no accounts are using this security | ||||
| 	accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) | ||||
| 	accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) | ||||
|  | ||||
| 	if accounts != 0 { | ||||
| 		transaction.Rollback() | ||||
| 		return SecurityInUseError{"One or more accounts still use this security"} | ||||
| 	} | ||||
|  | ||||
| 	user, err := GetUserTx(transaction, s.UserId) | ||||
| 	user, err := GetUserTx(tx, s.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if user.DefaultCurrency == s.SecurityId { | ||||
| 		transaction.Rollback() | ||||
| 		return SecurityInUseError{"Cannot delete security which is user's default currency"} | ||||
| 	} | ||||
|  | ||||
| 	// Remove all prices involving this security (either of this security, or | ||||
| 	// using it as a currency) | ||||
| 	_, err = transaction.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) | ||||
| 	_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(s) | ||||
| 	count, err := tx.Delete(s) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Deleted more than one security") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -294,43 +262,33 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi | ||||
| 	return security, nil | ||||
| } | ||||
|  | ||||
| func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSessionTx(tx, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 		return NewError(1 /*Not Signed In*/) | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		security_json := r.PostFormValue("security") | ||||
| 		if security_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		var security Security | ||||
| 		err := security.Read(security_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		security.SecurityId = -1 | ||||
| 		security.UserId = user.UserId | ||||
|  | ||||
| 		err = InsertSecurity(db, &security) | ||||
| 		err = InsertSecurityTx(tx, &security) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = security.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		return ResponseWrapper{201, &security} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		var securityid int64 | ||||
| 		n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid) | ||||
| @@ -339,87 +297,65 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 			//Return all securities | ||||
| 			var sl SecurityList | ||||
|  | ||||
| 			securities, err := GetSecurities(db, user.UserId) | ||||
| 			securities, err := GetSecurities(tx, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			sl.Securities = securities | ||||
| 			err = (&sl).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &sl | ||||
| 		} else { | ||||
| 			security, err := GetSecurity(db, securityid, user.UserId) | ||||
| 			security, err := GetSecurityTx(tx, securityid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = security.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return security | ||||
| 		} | ||||
| 	} else { | ||||
| 		securityid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		if r.Method == "PUT" { | ||||
| 			security_json := r.PostFormValue("security") | ||||
| 			if security_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			var security Security | ||||
| 			err := security.Read(security_json) | ||||
| 			if err != nil || security.SecurityId != securityid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			security.UserId = user.UserId | ||||
|  | ||||
| 			err = UpdateSecurity(db, &security) | ||||
| 			err = UpdateSecurity(tx, &security) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			err = security.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &security | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			security, err := GetSecurity(db, securityid, user.UserId) | ||||
| 			security, err := GetSecurityTx(tx, securityid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteSecurity(db, security) | ||||
| 			err = DeleteSecurity(tx, security) | ||||
| 			if _, ok := err.(SecurityInUseError); ok { | ||||
| 				WriteError(w, 7 /*In Use Error*/) | ||||
| 				return NewError(7 /*In Use Error*/) | ||||
| 			} else if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 			return SuccessWriter{} | ||||
| 		} | ||||
| 	} | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|  | ||||
| func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { | ||||
|   | ||||
| @@ -13,9 +13,9 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find DB in lua's Context") | ||||
| 		return nil, errors.New("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) | ||||
| @@ -25,7 +25,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		securities, err := GetSecurities(db, user.UserId) | ||||
| 		securities, err := GetSecurities(tx, user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -155,12 +155,12 @@ func luaClosestPrice(L *lua.LState) int { | ||||
| 	date := luaCheckTime(L, 3) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find DB in lua's Context") | ||||
| 		panic("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	p, err := GetClosestPrice(db, s, c, date) | ||||
| 	p, err := GetClosestPrice(tx, s, c, date) | ||||
| 	if err != nil { | ||||
| 		L.Push(lua.LNil) | ||||
| 	} else { | ||||
|   | ||||
| @@ -28,7 +28,7 @@ func (s *Session) Read(json_str string) error { | ||||
| 	return dec.Decode(s) | ||||
| } | ||||
|  | ||||
| func GetSession(db *DB, r *http.Request) (*Session, error) { | ||||
| func GetSession(tx *Tx, r *http.Request) (*Session, error) { | ||||
| 	var s Session | ||||
|  | ||||
| 	cookie, err := r.Cookie("moneygo-session") | ||||
| @@ -37,18 +37,33 @@ func GetSession(db *DB, r *http.Request) (*Session, error) { | ||||
| 	} | ||||
| 	s.SessionSecret = cookie.Value | ||||
|  | ||||
| 	err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) | ||||
| 	err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func DeleteSessionIfExists(db *DB, r *http.Request) error { | ||||
| 	// TODO do this in one transaction | ||||
| 	session, err := GetSession(db, r) | ||||
| func GetSessionTx(tx *Tx, r *http.Request) (*Session, error) { | ||||
| 	var s Session | ||||
|  | ||||
| 	cookie, err := r.Cookie("moneygo-session") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("moneygo-session cookie not set") | ||||
| 	} | ||||
| 	s.SessionSecret = cookie.Value | ||||
|  | ||||
| 	err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func DeleteSessionIfExists(tx *Tx, r *http.Request) error { | ||||
| 	session, err := GetSessionTx(tx, r) | ||||
| 	if err == nil { | ||||
| 		_, err := db.Delete(session) | ||||
| 		_, err := tx.Delete(session) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -64,7 +79,17 @@ func NewSessionCookie() (string, error) { | ||||
| 	return base64.StdEncoding.EncodeToString(bits), nil | ||||
| } | ||||
|  | ||||
| func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { | ||||
| type NewSessionWriter struct { | ||||
| 	session *Session | ||||
| 	cookie  *http.Cookie | ||||
| } | ||||
|  | ||||
| func (n *NewSessionWriter) Write(w http.ResponseWriter) error { | ||||
| 	http.SetCookie(w, n.cookie) | ||||
| 	return n.session.Write(w) | ||||
| } | ||||
|  | ||||
| func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { | ||||
| 	s := Session{} | ||||
|  | ||||
| 	session_secret, err := NewSessionCookie() | ||||
| @@ -81,79 +106,66 @@ func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (* | ||||
| 		Secure:   true, | ||||
| 		HttpOnly: true, | ||||
| 	} | ||||
| 	http.SetCookie(w, &cookie) | ||||
|  | ||||
| 	s.SessionSecret = session_secret | ||||
| 	s.UserId = userid | ||||
|  | ||||
| 	err = db.Insert(&s) | ||||
| 	err = tx.Insert(&s) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| 	return &NewSessionWriter{&s, &cookie}, nil | ||||
| } | ||||
|  | ||||
| func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	if r.Method == "POST" || r.Method == "PUT" { | ||||
| 		user_json := r.PostFormValue("user") | ||||
| 		if user_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		user := User{} | ||||
| 		err := user.Read(user_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		dbuser, err := GetUserByUsername(db, user.Username) | ||||
| 		dbuser, err := GetUserByUsername(tx, user.Username) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 2 /*Unauthorized Access*/) | ||||
| 			return | ||||
| 			return NewError(2 /*Unauthorized Access*/) | ||||
| 		} | ||||
|  | ||||
| 		user.HashPassword() | ||||
| 		if user.PasswordHash != dbuser.PasswordHash { | ||||
| 			WriteError(w, 2 /*Unauthorized Access*/) | ||||
| 			return | ||||
| 			return NewError(2 /*Unauthorized Access*/) | ||||
| 		} | ||||
|  | ||||
| 		err = DeleteSessionIfExists(db, r) | ||||
| 		err = DeleteSessionIfExists(tx, r) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
|  | ||||
| 		session, err := NewSession(db, w, r, dbuser.UserId) | ||||
| 		sessionwriter, err := NewSession(tx, r, dbuser.UserId) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = session.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		return sessionwriter | ||||
| 	} else if r.Method == "GET" { | ||||
| 		s, err := GetSession(db, r) | ||||
| 		s, err := GetSessionTx(tx, r) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 1 /*Not Signed In*/) | ||||
| 			return | ||||
| 			return NewError(1 /*Not Signed In*/) | ||||
| 		} | ||||
|  | ||||
| 		s.Write(w) | ||||
| 		return s | ||||
| 	} else if r.Method == "DELETE" { | ||||
| 		err := DeleteSessionIfExists(db, r) | ||||
| 		err := DeleteSessionIfExists(tx, r) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
| 		WriteSuccess(w) | ||||
| 		return SuccessWriter{} | ||||
| 	} | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|   | ||||
| @@ -178,72 +178,48 @@ func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { | ||||
| 	return true, nil | ||||
| } | ||||
|  | ||||
| func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) { | ||||
| func GetTransaction(tx *Tx, transactionid int64, userid int64) (*Transaction, error) { | ||||
| 	var t Transaction | ||||
|  | ||||
| 	transaction, err := db.Begin() | ||||
| 	err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) | ||||
| 	_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &t, nil | ||||
| } | ||||
|  | ||||
| func GetTransactions(db *DB, userid int64) (*[]Transaction, error) { | ||||
| func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { | ||||
| 	var transactions []Transaction | ||||
|  | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = transaction.Select(&transactions, "SELECT * from transactions where UserId=?", userid) | ||||
| 	_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	for i := range transactions { | ||||
| 		_, err := transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &transactions, nil | ||||
| } | ||||
|  | ||||
| func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error { | ||||
| func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { | ||||
| 	for i := range accountids { | ||||
| 		account, err := GetAccountTx(transaction, accountids[i], user.UserId) | ||||
| 		account, err := GetAccountTx(tx, accountids[i], user.UserId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		account.AccountVersion++ | ||||
| 		count, err := transaction.Update(account) | ||||
| 		count, err := tx.Update(account) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -260,12 +236,12 @@ func (ame AccountMissingError) Error() string { | ||||
| 	return "Account missing" | ||||
| } | ||||
|  | ||||
| func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { | ||||
| func InsertTransactionTx(tx *Tx, t *Transaction, user *User) error { | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
| 	for i := range t.Splits { | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) | ||||
| 			existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -287,13 +263,13 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 	if len(a_ids) < 1 { | ||||
| 		return AccountMissingError{} | ||||
| 	} | ||||
| 	err := incrementAccountVersions(transaction, user, a_ids) | ||||
| 	err := incrementAccountVersions(tx, user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	t.UserId = user.UserId | ||||
| 	err = transaction.Insert(t) | ||||
| 	err = tx.Insert(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -301,7 +277,7 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		t.Splits[i].SplitId = -1 | ||||
| 		err = transaction.Insert(t.Splits[i]) | ||||
| 		err = tx.Insert(t.Splits[i]) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -310,31 +286,19 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InsertTransaction(db *DB, t *Transaction, user *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| func InsertTransaction(tx *Tx, t *Transaction, user *User) error { | ||||
| 	err := InsertTransactionTx(tx, t, user) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = InsertTransactionTx(transaction, t, user) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { | ||||
| func UpdateTransactionTx(tx *Tx, t *Transaction, user *User) error { | ||||
| 	var existing_splits []*Split | ||||
|  | ||||
| 	_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) | ||||
| 	_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -353,7 +317,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		_, ok := s_map[t.Splits[i].SplitId] | ||||
| 		if ok { | ||||
| 			count, err := transaction.Update(t.Splits[i]) | ||||
| 			count, err := tx.Update(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -363,7 +327,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 			delete(s_map, t.Splits[i].SplitId) | ||||
| 		} else { | ||||
| 			t.Splits[i].SplitId = -1 | ||||
| 			err := transaction.Insert(t.Splits[i]) | ||||
| 			err := tx.Insert(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -380,7 +344,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 			a_map[existing_splits[i].AccountId] = true | ||||
| 		} | ||||
| 		if ok { | ||||
| 			_, err := transaction.Delete(existing_splits[i]) | ||||
| 			_, err := tx.Delete(existing_splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -392,12 +356,12 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	err = incrementAccountVersions(transaction, user, a_ids) | ||||
| 	err = incrementAccountVersions(tx, user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Update(t) | ||||
| 	count, err := tx.Update(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -408,257 +372,165 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteTransaction(db *DB, t *Transaction, user *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { | ||||
| 	var accountids []int64 | ||||
| 	_, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) | ||||
| 	_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	_, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) | ||||
| 	_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(t) | ||||
| 	count, err := tx.Delete(t) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Deleted more than one transaction") | ||||
| 	} | ||||
|  | ||||
| 	err = incrementAccountVersions(transaction, user, accountids) | ||||
| 	err = incrementAccountVersions(tx, user, accountids) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSession(tx, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 		return NewError(1 /*Not Signed In*/) | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		transaction_json := r.PostFormValue("transaction") | ||||
| 		if transaction_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		var transaction Transaction | ||||
| 		err := transaction.Read(transaction_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		transaction.TransactionId = -1 | ||||
| 		transaction.UserId = user.UserId | ||||
|  | ||||
| 		sqltx, err := db.Begin() | ||||
| 		balanced, err := transaction.Balanced(tx) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		balanced, err := transaction.Balanced(sqltx) | ||||
| 		if err != nil { | ||||
| 			sqltx.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		if !transaction.Valid() || !balanced { | ||||
| 			sqltx.Rollback() | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		for i := range transaction.Splits { | ||||
| 			transaction.Splits[i].SplitId = -1 | ||||
| 			_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 			_, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		err = InsertTransactionTx(sqltx, &transaction, user) | ||||
| 		err = InsertTransactionTx(tx, &transaction, user) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(AccountMissingError); ok { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} else { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			sqltx.Rollback() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = sqltx.Commit() | ||||
| 		if err != nil { | ||||
| 			sqltx.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = transaction.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		return &transaction | ||||
| 	} else if r.Method == "GET" { | ||||
| 		transactionid, err := GetURLID(r.URL.Path) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			//Return all Transactions | ||||
| 			var al TransactionList | ||||
| 			transactions, err := GetTransactions(db, user.UserId) | ||||
| 			transactions, err := GetTransactions(tx, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			al.Transactions = transactions | ||||
| 			err = (&al).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &al | ||||
| 		} else { | ||||
| 			//Return Transaction with this Id | ||||
| 			transaction, err := GetTransaction(db, transactionid, user.UserId) | ||||
| 			transaction, err := GetTransaction(tx, transactionid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			err = transaction.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			return transaction | ||||
| 		} | ||||
| 	} else { | ||||
| 		transactionid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		if r.Method == "PUT" { | ||||
| 			transaction_json := r.PostFormValue("transaction") | ||||
| 			if transaction_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			var transaction Transaction | ||||
| 			err := transaction.Read(transaction_json) | ||||
| 			if err != nil || transaction.TransactionId != transactionid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			transaction.UserId = user.UserId | ||||
|  | ||||
| 			sqltx, err := db.Begin() | ||||
| 			balanced, err := transaction.Balanced(tx) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			balanced, err := transaction.Balanced(sqltx) | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			if !transaction.Valid() || !balanced { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			for i := range transaction.Splits { | ||||
| 				_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 				_, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 				if err != nil { | ||||
| 					sqltx.Rollback() | ||||
| 					WriteError(w, 3 /*Invalid Request*/) | ||||
| 					return | ||||
| 					return NewError(3 /*Invalid Request*/) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateTransactionTx(sqltx, &transaction, user) | ||||
| 			err = UpdateTransactionTx(tx, &transaction, user) | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			err = sqltx.Commit() | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = transaction.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return &transaction | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			transactionid, err := GetURLID(r.URL.Path) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			transaction, err := GetTransaction(db, transactionid, user.UserId) | ||||
| 			transaction, err := GetTransaction(tx, transactionid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteTransaction(db, transaction, user) | ||||
| 			err = DeleteTransaction(tx, transaction, user) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 			return SuccessWriter{} | ||||
| 		} | ||||
| 	} | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|  | ||||
| func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { | ||||
| @@ -685,17 +557,12 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6 | ||||
| 	return &pageDifference, nil | ||||
| } | ||||
|  | ||||
| func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { | ||||
| func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { | ||||
| 	var splits []Split | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" | ||||
| 	_, err = transaction.Select(&splits, sql, accountid, user.UserId) | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @@ -703,34 +570,22 @@ func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| // Assumes accountid is valid and is owned by the current user | ||||
| func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) { | ||||
| func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) { | ||||
| 	var splits []Split | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" | ||||
| 	_, err = transaction.Select(&splits, sql, accountid, user.UserId, date) | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId, date) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @@ -738,33 +593,21 @@ func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { | ||||
| func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { | ||||
| 	var splits []Split | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" | ||||
| 	_, err = transaction.Select(&splits, sql, accountid, user.UserId, begin, end) | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @@ -772,31 +615,19 @@ func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { | ||||
| func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { | ||||
| 	var transactions []Transaction | ||||
| 	var atl AccountTransactionsList | ||||
|  | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var sqlsort, balanceLimitOffset string | ||||
| 	var balanceLimitOffsetArg uint64 | ||||
| 	if sort == "date-asc" { | ||||
| @@ -804,9 +635,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa | ||||
| 		balanceLimitOffset = " LIMIT ?" | ||||
| 		balanceLimitOffsetArg = page * limit | ||||
| 	} else if sort == "date-desc" { | ||||
| 		numSplits, err := transaction.SelectInt("SELECT count(*) FROM splits") | ||||
| 		numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlsort = " ORDER BY transactions.Date DESC" | ||||
| @@ -819,41 +649,35 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa | ||||
| 		sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) | ||||
| 	} | ||||
|  | ||||
| 	account, err := GetAccountTx(transaction, accountid, user.UserId) | ||||
| 	account, err := GetAccountTx(tx, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Account = account | ||||
|  | ||||
| 	sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset | ||||
| 	_, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit) | ||||
| 	_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Transactions = &transactions | ||||
|  | ||||
| 	pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) | ||||
| 	pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) | ||||
| 	count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.TotalTransactions = count | ||||
|  | ||||
| 	security, err := GetSecurityTx(transaction, atl.Account.SecurityId, user.UserId) | ||||
| 	security, err := GetSecurityTx(tx, atl.Account.SecurityId, user.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if security == nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, errors.New("Security not found") | ||||
| 	} | ||||
|  | ||||
| @@ -861,9 +685,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa | ||||
| 	// occurred before the page we're returning | ||||
| 	var amounts []string | ||||
| 	sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")" | ||||
| 	_, err = transaction.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) | ||||
| 	_, err = tx.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @@ -871,7 +694,6 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa | ||||
| 	for _, amount := range amounts { | ||||
| 		rat_amount, err := GetBigAmount(amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| @@ -880,20 +702,12 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa | ||||
| 	atl.BeginningBalance = balance.FloatString(security.Precision) | ||||
| 	atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &atl, nil | ||||
| } | ||||
|  | ||||
| // Return only those transactions which have at least one split pertaining to | ||||
| // an account | ||||
| func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, | ||||
| 	user *User, accountid int64) { | ||||
|  | ||||
| func AccountTransactionsHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { | ||||
| 	var page uint64 = 0 | ||||
| 	var limit uint64 = 50 | ||||
| 	var sort string = "date-desc" | ||||
| @@ -904,8 +718,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, | ||||
| 	if pagestring != "" { | ||||
| 		p, err := strconv.ParseUint(pagestring, 10, 0) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		page = p | ||||
| 	} | ||||
| @@ -914,8 +727,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, | ||||
| 	if limitstring != "" { | ||||
| 		l, err := strconv.ParseUint(limitstring, 10, 0) | ||||
| 		if err != nil || l > 100 { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		limit = l | ||||
| 	} | ||||
| @@ -923,23 +735,16 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, | ||||
| 	sortstring := query.Get("sort") | ||||
| 	if sortstring != "" { | ||||
| 		if sortstring != "date-asc" && sortstring != "date-desc" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		sort = sortstring | ||||
| 	} | ||||
|  | ||||
| 	accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit) | ||||
| 	accountTransactions, err := GetAccountTransactions(tx, user, accountid, sort, page, limit) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 		return NewError(999 /*Internal Error*/) | ||||
| 	} | ||||
|  | ||||
| 	err = accountTransactions.Write(w) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
| 	return accountTransactions | ||||
| } | ||||
|   | ||||
| @@ -5,7 +5,6 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| @@ -47,61 +46,52 @@ func (u *User) HashPassword() { | ||||
| 	u.Password = "" | ||||
| } | ||||
|  | ||||
| func GetUser(db *DB, userid int64) (*User, error) { | ||||
| func GetUser(tx *Tx, userid int64) (*User, error) { | ||||
| 	var u User | ||||
|  | ||||
| 	err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) { | ||||
| func GetUserTx(tx *Tx, userid int64) (*User, error) { | ||||
| 	var u User | ||||
|  | ||||
| 	err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func GetUserByUsername(db *DB, username string) (*User, error) { | ||||
| func GetUserByUsername(tx *Tx, username string) (*User, error) { | ||||
| 	var u User | ||||
|  | ||||
| 	err := db.SelectOne(&u, "SELECT * from users where Username=?", username) | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func InsertUser(db *DB, u *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| func InsertUser(tx *Tx, u *User) error { | ||||
| 	security_template := FindCurrencyTemplate(u.DefaultCurrency) | ||||
| 	if security_template == nil { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Invalid ISO4217 Default Currency") | ||||
| 	} | ||||
|  | ||||
| 	existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username) | ||||
| 	existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if existing > 0 { | ||||
| 		transaction.Rollback() | ||||
| 		return UserExistsError{} | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Insert(u) | ||||
| 	err = tx.Insert(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @@ -110,201 +100,146 @@ func InsertUser(db *DB, u *User) error { | ||||
| 	security = *security_template | ||||
| 	security.UserId = u.UserId | ||||
|  | ||||
| 	err = InsertSecurityTx(transaction, &security) | ||||
| 	err = InsertSecurityTx(tx, &security) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Update the user's DefaultCurrency to our new SecurityId | ||||
| 	u.DefaultCurrency = security.SecurityId | ||||
| 	count, err := transaction.Update(u) | ||||
| 	count, err := tx.Update(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Would have updated more than one user") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetUserFromSession(db *DB, r *http.Request) (*User, error) { | ||||
| 	s, err := GetSession(db, r) | ||||
| func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { | ||||
| 	s, err := GetSession(tx, r) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return GetUser(db, s.UserId) | ||||
| 	return GetUser(tx, s.UserId) | ||||
| } | ||||
|  | ||||
| func UpdateUser(db *DB, u *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| func GetUserFromSessionTx(tx *Tx, r *http.Request) (*User, error) { | ||||
| 	s, err := GetSessionTx(tx, r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return GetUserTx(tx, s.UserId) | ||||
| } | ||||
|  | ||||
| 	security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId) | ||||
| func UpdateUser(tx *Tx, u *User) error { | ||||
| 	security, err := GetSecurityTx(tx, u.DefaultCurrency, u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("UserId and DefaultCurrency don't match the fetched security") | ||||
| 	} else if security.Type != Currency { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("New DefaultCurrency security is not a currency") | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Update(u) | ||||
| 	count, err := tx.Update(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Would have updated more than one user") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteUser(db *DB, u *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| func DeleteUser(tx *Tx, u *User) error { | ||||
| 	count, err := tx.Delete(u) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return fmt.Errorf("No user to delete") | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = transaction.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) | ||||
| 	_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { | ||||
| 	if r.Method == "POST" { | ||||
| 		user_json := r.PostFormValue("user") | ||||
| 		if user_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		var user User | ||||
| 		err := user.Read(user_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		user.UserId = -1 | ||||
| 		user.HashPassword() | ||||
|  | ||||
| 		err = InsertUser(db, &user) | ||||
| 		err = InsertUser(tx, &user) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(UserExistsError); ok { | ||||
| 				WriteError(w, 4 /*User Exists*/) | ||||
| 				return NewError(4 /*User Exists*/) | ||||
| 			} else { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = user.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		return ResponseWrapper{201, &user} | ||||
| 	} else { | ||||
| 		user, err := GetUserFromSession(db, r) | ||||
| 		user, err := GetUserFromSession(tx, r) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 1 /*Not Signed In*/) | ||||
| 			return | ||||
| 			return NewError(1 /*Not Signed In*/) | ||||
| 		} | ||||
|  | ||||
| 		userid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		if userid != user.UserId { | ||||
| 			WriteError(w, 2 /*Unauthorized Access*/) | ||||
| 			return | ||||
| 			return NewError(2 /*Unauthorized Access*/) | ||||
| 		} | ||||
|  | ||||
| 		if r.Method == "GET" { | ||||
| 			err = user.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return user | ||||
| 		} else if r.Method == "PUT" { | ||||
| 			user_json := r.PostFormValue("user") | ||||
| 			if user_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			// Save old PWHash in case the new password is bogus | ||||
| @@ -312,8 +247,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
|  | ||||
| 			err = user.Read(user_json) | ||||
| 			if err != nil || user.UserId != userid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			// If the user didn't create a new password, keep their old one | ||||
| @@ -324,27 +258,21 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 				user.PasswordHash = old_pwhash | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateUser(db, user) | ||||
| 			err = UpdateUser(tx, user) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			err = user.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			return user | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			err := DeleteUser(db, user) | ||||
| 			err := DeleteUser(tx, user) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
| 			WriteSuccess(w) | ||||
| 			return SuccessWriter{} | ||||
| 		} | ||||
| 	} | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|   | ||||
| @@ -18,6 +18,23 @@ func GetURLPieces(url string, format string, a ...interface{}) (int, error) { | ||||
| 	return fmt.Sscanf(url, format, a...) | ||||
| } | ||||
|  | ||||
| type ResponseWrapper struct { | ||||
| 	Code   int | ||||
| 	Writer ResponseWriterWriter | ||||
| } | ||||
|  | ||||
| func (r ResponseWrapper) Write(w http.ResponseWriter) error { | ||||
| 	w.WriteHeader(r.Code) | ||||
| 	return r.Writer.Write(w) | ||||
| } | ||||
|  | ||||
| type SuccessWriter struct{} | ||||
|  | ||||
| func (s SuccessWriter) Write(w http.ResponseWriter) error { | ||||
| 	fmt.Fprint(w, "{}") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func WriteSuccess(w http.ResponseWriter) { | ||||
| 	fmt.Fprint(w, "{}") | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user