mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-10-30 09:33:25 -04:00 
			
		
		
		
	Merge pull request #35 from aclindsa/store_split
Split DB activity into 'store'
This commit is contained in:
		| @@ -20,12 +20,11 @@ env: | ||||
|     - MONEYGO_TEST_DB=mysql | ||||
|     - MONEYGO_TEST_DB=postgres | ||||
|  | ||||
| # OSX builds take too long, so don't wait for all of them | ||||
| # OSX builds take too long, so don't wait for them | ||||
| matrix: | ||||
|   fast_finish: true | ||||
|   allow_failures: | ||||
|     - os: osx | ||||
|       go: master | ||||
|  | ||||
| before_install: | ||||
|   # Fetch/build coverage reporting tools | ||||
|   | ||||
| @@ -3,43 +3,22 @@ package handlers | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) { | ||||
| 	var a models.Account | ||||
|  | ||||
| 	err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &a, nil | ||||
| } | ||||
|  | ||||
| func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { | ||||
| 	var accounts []models.Account | ||||
|  | ||||
| 	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &accounts, nil | ||||
| } | ||||
|  | ||||
| // Get (and attempt to create if it doesn't exist). Matches on UserId, | ||||
| // SecurityId, Type, Name, and ParentAccountId | ||||
| func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { | ||||
| 	var accounts []models.Account | ||||
| func GetCreateAccount(tx store.Tx, a models.Account) (*models.Account, error) { | ||||
| 	var account models.Account | ||||
|  | ||||
| 	// Try to find the top-level trading account | ||||
| 	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) | ||||
| 	accounts, err := tx.FindMatchingAccounts(&a) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(accounts) == 1 { | ||||
| 		account = accounts[0] | ||||
| 	if len(*accounts) > 0 { | ||||
| 		account = *(*accounts)[0] | ||||
| 	} else { | ||||
| 		account.UserId = a.UserId | ||||
| 		account.SecurityId = a.SecurityId | ||||
| @@ -47,7 +26,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { | ||||
| 		account.Name = a.Name | ||||
| 		account.ParentAccountId = a.ParentAccountId | ||||
|  | ||||
| 		err = tx.Insert(&account) | ||||
| 		err = tx.InsertAccount(&account) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -57,11 +36,11 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { | ||||
|  | ||||
| // Get (and attempt to create if it doesn't exist) the security/currency | ||||
| // trading account for the supplied security/currency | ||||
| func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { | ||||
| func GetTradingAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) { | ||||
| 	var tradingAccount models.Account | ||||
| 	var account models.Account | ||||
|  | ||||
| 	user, err := GetUser(tx, userid) | ||||
| 	user, err := tx.GetUser(userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -78,7 +57,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	security, err := GetSecurity(tx, securityid, userid) | ||||
| 	security, err := tx.GetSecurity(securityid, userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -99,7 +78,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, | ||||
|  | ||||
| // Get (and attempt to create if it doesn't exist) the security/currency | ||||
| // imbalance account for the supplied security/currency | ||||
| func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { | ||||
| func GetImbalanceAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) { | ||||
| 	var imbalanceAccount models.Account | ||||
| 	var account models.Account | ||||
| 	xxxtemplate := FindSecurityTemplate("XXX", models.Currency) | ||||
| @@ -123,7 +102,7 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	security, err := GetSecurity(tx, securityid, userid) | ||||
| 	security, err := tx.GetSecurity(securityid, userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -142,120 +121,6 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun | ||||
| 	return a, nil | ||||
| } | ||||
|  | ||||
| type ParentAccountMissingError struct{} | ||||
|  | ||||
| func (pame ParentAccountMissingError) Error() string { | ||||
| 	return "Parent account missing" | ||||
| } | ||||
|  | ||||
| type TooMuchNestingError struct{} | ||||
|  | ||||
| func (tmne TooMuchNestingError) Error() string { | ||||
| 	return "Too much nesting" | ||||
| } | ||||
|  | ||||
| type CircularAccountsError struct{} | ||||
|  | ||||
| func (cae CircularAccountsError) Error() string { | ||||
| 	return "Would result in circular account relationship" | ||||
| } | ||||
|  | ||||
| func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { | ||||
| 	found := make(map[int64]bool) | ||||
| 	if !insert { | ||||
| 		found[a.AccountId] = true | ||||
| 	} | ||||
| 	parentid := a.ParentAccountId | ||||
| 	depth := 0 | ||||
| 	for parentid != -1 { | ||||
| 		depth += 1 | ||||
| 		if depth > 100 { | ||||
| 			return TooMuchNestingError{} | ||||
| 		} | ||||
|  | ||||
| 		var a models.Account | ||||
| 		err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) | ||||
| 		if err != nil { | ||||
| 			return ParentAccountMissingError{} | ||||
| 		} | ||||
|  | ||||
| 		// Insertion by itself can never result in circular dependencies | ||||
| 		if insert { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		found[parentid] = true | ||||
| 		parentid = a.ParentAccountId | ||||
| 		if _, ok := found[parentid]; ok { | ||||
| 			return CircularAccountsError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if insert { | ||||
| 		err := tx.Insert(a) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		oldacct, err := GetAccount(tx, a.AccountId, a.UserId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		a.AccountVersion = oldacct.AccountVersion + 1 | ||||
|  | ||||
| 		count, err := tx.Update(a) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InsertAccount(tx *Tx, a *models.Account) error { | ||||
| 	return insertUpdateAccount(tx, a, true) | ||||
| } | ||||
|  | ||||
| func UpdateAccount(tx *Tx, a *models.Account) error { | ||||
| 	return insertUpdateAccount(tx, a, false) | ||||
| } | ||||
|  | ||||
| func DeleteAccount(tx *Tx, a *models.Account) error { | ||||
| 	if a.ParentAccountId != -1 { | ||||
| 		// Re-parent splits to this account's parent account if this account isn't a root account | ||||
| 		_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Delete splits if this account is a root account | ||||
| 		_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Re-parent child accounts to this account's parent account | ||||
| 	_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Delete(a) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Was going to delete more than one account") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSession(context.Tx, r) | ||||
| 	if err != nil { | ||||
| @@ -279,7 +144,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 		account.UserId = user.UserId | ||||
| 		account.AccountVersion = 0 | ||||
|  | ||||
| 		security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) | ||||
| 		security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| @@ -288,9 +153,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		err = InsertAccount(context.Tx, &account) | ||||
| 		err = context.Tx.InsertAccount(&account) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(ParentAccountMissingError); ok { | ||||
| 			if _, ok := err.(store.ParentAccountMissingError); ok { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} else { | ||||
| 				log.Print(err) | ||||
| @@ -303,7 +168,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 		if context.LastLevel() { | ||||
| 			//Return all Accounts | ||||
| 			var al models.AccountList | ||||
| 			accounts, err := GetAccounts(context.Tx, user.UserId) | ||||
| 			accounts, err := context.Tx.GetAccounts(user.UserId) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -319,7 +184,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
|  | ||||
| 		if context.LastLevel() { | ||||
| 			// Return Account with this Id | ||||
| 			account, err := GetAccount(context.Tx, accountid, user.UserId) | ||||
| 			account, err := context.Tx.GetAccount(accountid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| @@ -340,7 +205,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 			} | ||||
| 			account.UserId = user.UserId | ||||
|  | ||||
| 			security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) | ||||
| 			security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -353,11 +218,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateAccount(context.Tx, &account) | ||||
| 			err = context.Tx.UpdateAccount(&account) | ||||
| 			if err != nil { | ||||
| 				if _, ok := err.(ParentAccountMissingError); ok { | ||||
| 				if _, ok := err.(store.ParentAccountMissingError); ok { | ||||
| 					return NewError(3 /*Invalid Request*/) | ||||
| 				} else if _, ok := err.(CircularAccountsError); ok { | ||||
| 				} else if _, ok := err.(store.CircularAccountsError); ok { | ||||
| 					return NewError(3 /*Invalid Request*/) | ||||
| 				} else { | ||||
| 					log.Print(err) | ||||
| @@ -367,12 +232,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
|  | ||||
| 			return &account | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			account, err := GetAccount(context.Tx, accountid, user.UserId) | ||||
| 			account, err := context.Tx.GetAccount(accountid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteAccount(context.Tx, account) | ||||
| 			err = context.Tx.DeleteAccount(account) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
|   | ||||
| @@ -4,8 +4,8 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"math/big" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| @@ -16,7 +16,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
| @@ -28,14 +28,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		accounts, err := GetAccounts(tx, user.UserId) | ||||
| 		accounts, err := tx.GetAccounts(user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		account_map = make(map[int64]*models.Account) | ||||
| 		for i := range *accounts { | ||||
| 			account_map[(*accounts)[i].AccountId] = &(*accounts)[i] | ||||
| 			account_map[(*accounts)[i].AccountId] = (*accounts)[i] | ||||
| 		} | ||||
|  | ||||
| 		ctx = context.WithValue(ctx, accountsContextKey, account_map) | ||||
| @@ -150,7 +150,7 @@ func luaAccountBalance(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
| @@ -167,24 +167,29 @@ func luaAccountBalance(L *lua.LState) int { | ||||
| 		panic("SecurityId not in lua security_map") | ||||
| 	} | ||||
| 	date := luaWeakCheckTime(L, 2) | ||||
| 	var b Balance | ||||
| 	var rat *big.Rat | ||||
| 	var splits *[]*models.Split | ||||
| 	if date != nil { | ||||
| 		end := luaWeakCheckTime(L, 3) | ||||
| 		if end != nil { | ||||
| 			rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) | ||||
| 			splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end) | ||||
| 		} else { | ||||
| 			rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) | ||||
| 			splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date) | ||||
| 		} | ||||
| 	} else { | ||||
| 		rat, err = GetAccountBalance(tx, user, a.AccountId) | ||||
| 		splits, err = tx.GetAccountSplits(user, a.AccountId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		panic("Failed to GetAccountBalance:" + err.Error()) | ||||
| 		panic("Failed to fetch splits for account:" + err.Error()) | ||||
| 	} | ||||
| 	b.Amount = rat | ||||
| 	b.Security = security | ||||
| 	L.Push(BalanceToLua(L, &b)) | ||||
| 	rat, err := BalanceFromSplits(splits) | ||||
| 	if err != nil { | ||||
| 		panic("Failed to calculate balance for account:" + err.Error()) | ||||
| 	} | ||||
| 	b := &Balance{ | ||||
| 		Amount:   rat, | ||||
| 		Security: security, | ||||
| 	} | ||||
| 	L.Push(BalanceToLua(L, b)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|   | ||||
| @@ -2,12 +2,11 @@ package handlers_test | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"github.com/aclindsa/moneygo/internal/config" | ||||
| 	"github.com/aclindsa/moneygo/internal/db" | ||||
| 	"github.com/aclindsa/moneygo/internal/handlers" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store/db" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| @@ -253,24 +252,15 @@ func RunTests(m *testing.M) int { | ||||
| 		dsn = envDSN | ||||
| 	} | ||||
|  | ||||
| 	dsn = db.GetDSN(dbType, dsn) | ||||
| 	database, err := sql.Open(dbType.String(), dsn) | ||||
| 	db, err := db.GetStore(dbType, dsn) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer database.Close() | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	dbmap, err := db.GetDbMap(database, dbType) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	db.Empty() // clear the DB tables | ||||
|  | ||||
| 	err = dbmap.TruncateTables() | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap}) | ||||
| 	server = httptest.NewTLSServer(&handlers.APIHandler{Store: db}) | ||||
| 	defer server.Close() | ||||
|  | ||||
| 	return m.Run() | ||||
|   | ||||
| @@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite | ||||
| 			} | ||||
| 			split.AccountId = acctId | ||||
|  | ||||
| 			exists, err := SplitAlreadyImported(context.Tx, split) | ||||
| 			exists, err := context.Tx.SplitExists(split) | ||||
| 			if err != nil { | ||||
| 				log.Print("Error checking if split was already imported:", err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -446,7 +446,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite | ||||
| 			} | ||||
| 		} | ||||
| 		if !already_imported { | ||||
| 			err := InsertTransaction(context.Tx, &transaction, user) | ||||
| 			err := context.Tx.InsertTransaction(&transaction, user) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
|   | ||||
| @@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) { | ||||
| 		} | ||||
| 		for i, account := range *accounts.Accounts { | ||||
| 			if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 { | ||||
| 				income = &(*accounts.Accounts)[i] | ||||
| 				income = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 { | ||||
| 				equity = &(*accounts.Accounts)[i] | ||||
| 				equity = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 { | ||||
| 				liabilities = &(*accounts.Accounts)[i] | ||||
| 				liabilities = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 { | ||||
| 				expenses = &(*accounts.Accounts)[i] | ||||
| 				expenses = (*accounts.Accounts)[i] | ||||
| 			} | ||||
| 		} | ||||
| 		if income == nil { | ||||
| @@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) { | ||||
| 		} | ||||
| 		for i, account := range *accounts.Accounts { | ||||
| 			if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId { | ||||
| 				salary = &(*accounts.Accounts)[i] | ||||
| 				salary = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId { | ||||
| 				openingbalances = &(*accounts.Accounts)[i] | ||||
| 				openingbalances = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId { | ||||
| 				creditcard = &(*accounts.Accounts)[i] | ||||
| 				creditcard = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { | ||||
| 				groceries = &(*accounts.Accounts)[i] | ||||
| 				groceries = (*accounts.Accounts)[i] | ||||
| 			} else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { | ||||
| 				cable = &(*accounts.Accounts)[i] | ||||
| 				cable = (*accounts.Accounts)[i] | ||||
| 			} | ||||
| 		} | ||||
| 		if salary == nil { | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aclindsa/gorp" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/aclindsa/moneygo/internal/store/db" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"path" | ||||
| @@ -16,7 +17,7 @@ type ResponseWriterWriter interface { | ||||
| } | ||||
|  | ||||
| type Context struct { | ||||
| 	Tx           *Tx | ||||
| 	Tx           store.Tx | ||||
| 	User         *models.User | ||||
| 	remainingURL string // portion of URL path not yet reached in the hierarchy | ||||
| } | ||||
| @@ -46,11 +47,11 @@ func (c *Context) LastLevel() bool { | ||||
| type Handler func(*http.Request, *Context) ResponseWriterWriter | ||||
|  | ||||
| type APIHandler struct { | ||||
| 	DB *gorp.DbMap | ||||
| 	Store *db.DbStore | ||||
| } | ||||
|  | ||||
| func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { | ||||
| 	tx, err := GetTx(ah.DB) | ||||
| 	tx, err := ah.Store.Begin() | ||||
| 	if err != nil { | ||||
| 		log.Print(err) | ||||
| 		return NewError(999 /*Internal Error*/) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package handlers | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/aclindsa/ofxgo" | ||||
| 	"io" | ||||
| 	"log" | ||||
| @@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error { | ||||
| 	return dec.Decode(od) | ||||
| } | ||||
|  | ||||
| func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { | ||||
| func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { | ||||
| 	itl, err := ImportOFX(r) | ||||
|  | ||||
| 	if err != nil { | ||||
| @@ -38,7 +39,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re | ||||
| 	} | ||||
|  | ||||
| 	// Return Account with this Id | ||||
| 	account, err := GetAccount(tx, accountid, user.UserId) | ||||
| 	account, err := tx.GetAccount(accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		log.Print(err) | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| @@ -158,7 +159,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re | ||||
| 				split := new(models.Split) | ||||
| 				r := new(big.Rat) | ||||
| 				r.Neg(&imbalance) | ||||
| 				security, err := GetSecurity(tx, imbalanced_security, user.UserId) | ||||
| 				security, err := tx.GetSecurity(imbalanced_security, user.UserId) | ||||
| 				if err != nil { | ||||
| 					log.Print(err) | ||||
| 					return NewError(999 /*Internal Error*/) | ||||
| @@ -186,7 +187,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re | ||||
| 				split.SecurityId = -1 | ||||
| 			} | ||||
|  | ||||
| 			exists, err := SplitAlreadyImported(tx, split) | ||||
| 			exists, err := tx.SplitExists(split) | ||||
| 			if err != nil { | ||||
| 				log.Print("Error checking if split was already imported:", err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -201,7 +202,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re | ||||
| 	} | ||||
|  | ||||
| 	for _, transaction := range transactions { | ||||
| 		err := InsertTransaction(tx, &transaction, user) | ||||
| 		err := tx.InsertTransaction(&transaction, user) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| @@ -217,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	account, err := GetAccount(context.Tx, accountid, user.UserId) | ||||
| 	account, err := context.Tx.GetAccount(accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|   | ||||
| @@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu | ||||
| 	} | ||||
| 	for _, account := range *accounts.Accounts { | ||||
| 		if account.Name == name && account.Type == tipe && account.SecurityId == securityid { | ||||
| 			return &account, nil | ||||
| 			return account, nil | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, fmt.Errorf("Unable to find account: \"%s\"", name) | ||||
|   | ||||
| @@ -2,82 +2,41 @@ package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { | ||||
| func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error { | ||||
| 	if len(price.RemoteId) == 0 { | ||||
| 		// Always create a new price if we can't match on the RemoteId | ||||
| 		err := tx.Insert(price) | ||||
| 		err := tx.InsertPrice(price) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var prices []*models.Price | ||||
|  | ||||
| 	_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) | ||||
| 	exists, err := tx.PriceExists(price) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if len(prices) > 0 { | ||||
| 	if exists { | ||||
| 		return nil // price already exists | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Insert(price) | ||||
| 	err = tx.InsertPrice(price) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { | ||||
| 	var p models.Price | ||||
| 	err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &p, nil | ||||
| } | ||||
|  | ||||
| func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { | ||||
| 	var prices []*models.Price | ||||
|  | ||||
| 	_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &prices, nil | ||||
| } | ||||
|  | ||||
| // Return the latest price for security in currency units before date | ||||
| func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	var p models.Price | ||||
| 	err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &p, nil | ||||
| } | ||||
|  | ||||
| // Return the earliest price for security in currency units after date | ||||
| func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	var p models.Price | ||||
| 	err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &p, nil | ||||
| } | ||||
|  | ||||
| // Return the price for security in currency closest to date | ||||
| func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	earliest, _ := GetEarliestPrice(tx, security, currency, date) | ||||
| 	latest, err := GetLatestPrice(tx, security, currency, date) | ||||
| func GetClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	earliest, _ := tx.GetEarliestPrice(security, currency, date) | ||||
| 	latest, err := tx.GetLatestPrice(security, currency, date) | ||||
|  | ||||
| 	// Return early if either earliest or latest are invalid | ||||
| 	if earliest == nil { | ||||
| @@ -96,7 +55,7 @@ func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Tim | ||||
| } | ||||
|  | ||||
| func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { | ||||
| 	security, err := GetSecurity(context.Tx, securityid, user.UserId) | ||||
| 	security, err := context.Tx.GetSecurity(securityid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
| @@ -111,12 +70,12 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security | ||||
| 		if price.SecurityId != security.SecurityId { | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| 		_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) | ||||
| 		_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId) | ||||
| 		if err != nil { | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		err = context.Tx.Insert(&price) | ||||
| 		err = context.Tx.InsertPrice(&price) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| @@ -128,7 +87,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security | ||||
| 			//Return all this security's prices | ||||
| 			var pl models.PriceList | ||||
|  | ||||
| 			prices, err := GetPrices(context.Tx, security.SecurityId) | ||||
| 			prices, err := context.Tx.GetPrices(security.SecurityId) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -143,7 +102,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		price, err := GetPrice(context.Tx, priceid, security.SecurityId) | ||||
| 		price, err := context.Tx.GetPrice(priceid, security.SecurityId) | ||||
| 		if err != nil { | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
| @@ -160,30 +119,30 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			_, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) | ||||
| 			_, err = context.Tx.GetSecurity(price.SecurityId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) | ||||
| 			_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			count, err := context.Tx.Update(&price) | ||||
| 			if err != nil || count != 1 { | ||||
| 			err = context.Tx.UpdatePrice(&price) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|  | ||||
| 			return &price | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			price, err := GetPrice(context.Tx, priceid, security.SecurityId) | ||||
| 			price, err := context.Tx.GetPrice(priceid, security.SecurityId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			count, err := context.Tx.Delete(price) | ||||
| 			if err != nil || count != 1 { | ||||
| 			err = context.Tx.DeletePrice(price) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| 			} | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| @@ -24,57 +25,7 @@ const ( | ||||
|  | ||||
| const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for | ||||
|  | ||||
| func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { | ||||
| 	var r models.Report | ||||
|  | ||||
| 	err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &r, nil | ||||
| } | ||||
|  | ||||
| func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { | ||||
| 	var reports []models.Report | ||||
|  | ||||
| 	_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &reports, nil | ||||
| } | ||||
|  | ||||
| func InsertReport(tx *Tx, r *models.Report) error { | ||||
| 	err := tx.Insert(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateReport(tx *Tx, r *models.Report) error { | ||||
| 	count, err := tx.Update(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Updated more than one report") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteReport(tx *Tx, r *models.Report) error { | ||||
| 	count, err := tx.Delete(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Deleted more than one report") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { | ||||
| func runReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { | ||||
| 	// Create a new LState without opening the default libs for security | ||||
| 	L := lua.NewState(lua.Options{SkipOpenLibs: true}) | ||||
| 	defer L.Close() | ||||
| @@ -138,8 +89,8 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { | ||||
| 	report, err := GetReport(tx, reportid, user.UserId) | ||||
| func ReportTabulationHandler(tx store.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { | ||||
| 	report, err := tx.GetReport(reportid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
| @@ -174,7 +125,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		err = InsertReport(context.Tx, &report) | ||||
| 		err = context.Tx.InsertReport(&report) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| @@ -185,7 +136,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 		if context.LastLevel() { | ||||
| 			//Return all Reports | ||||
| 			var rl models.ReportList | ||||
| 			reports, err := GetReports(context.Tx, user.UserId) | ||||
| 			reports, err := context.Tx.GetReports(user.UserId) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -203,7 +154,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 			return ReportTabulationHandler(context.Tx, r, user, reportid) | ||||
| 		} else { | ||||
| 			// Return Report with this Id | ||||
| 			report, err := GetReport(context.Tx, reportid, user.UserId) | ||||
| 			report, err := context.Tx.GetReport(reportid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| @@ -227,7 +178,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateReport(context.Tx, &report) | ||||
| 			err = context.Tx.UpdateReport(&report) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -235,12 +186,12 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
|  | ||||
| 			return &report | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			report, err := GetReport(context.Tx, reportid, user.UserId) | ||||
| 			report, err := context.Tx.GetReport(reportid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteReport(context.Tx, report) | ||||
| 			err = context.Tx.DeleteReport(report) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
|   | ||||
| @@ -4,8 +4,8 @@ package handlers | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| @@ -50,108 +50,34 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { | ||||
| 	var s models.Security | ||||
|  | ||||
| 	err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { | ||||
| 	var securities []*models.Security | ||||
|  | ||||
| 	_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &securities, nil | ||||
| } | ||||
|  | ||||
| func InsertSecurity(tx *Tx, s *models.Security) error { | ||||
| 	err := tx.Insert(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateSecurity(tx *Tx, s *models.Security) (err error) { | ||||
| 	user, err := GetUser(tx, s.UserId) | ||||
| func UpdateSecurity(tx store.Tx, s *models.Security) (err error) { | ||||
| 	user, err := tx.GetUser(s.UserId) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency { | ||||
| 		return errors.New("Cannot change security which is user's default currency to be non-currency") | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Update(s) | ||||
| 	err = tx.UpdateSecurity(s) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if count > 1 { | ||||
| 		return fmt.Errorf("Updated %d securities (expected 1)", count) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type SecurityInUseError struct { | ||||
| 	message string | ||||
| } | ||||
|  | ||||
| func (e SecurityInUseError) Error() string { | ||||
| 	return e.message | ||||
| } | ||||
|  | ||||
| func DeleteSecurity(tx *Tx, s *models.Security) error { | ||||
| 	// First, ensure no accounts are using this security | ||||
| 	accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) | ||||
|  | ||||
| 	if accounts != 0 { | ||||
| 		return SecurityInUseError{"One or more accounts still use this security"} | ||||
| 	} | ||||
|  | ||||
| 	user, err := GetUser(tx, s.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if user.DefaultCurrency == s.SecurityId { | ||||
| 		return SecurityInUseError{"Cannot delete security which is user's default currency"} | ||||
| 	} | ||||
|  | ||||
| 	// Remove all prices involving this security (either of this security, or | ||||
| 	// using it as a currency) | ||||
| 	_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Delete(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Deleted more than one security") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) { | ||||
| func ImportGetCreateSecurity(tx store.Tx, userid int64, security *models.Security) (*models.Security, error) { | ||||
| 	security.UserId = userid | ||||
| 	if len(security.AlternateId) == 0 { | ||||
| 		// Always create a new local security if we can't match on the AlternateId | ||||
| 		err := InsertSecurity(tx, security) | ||||
| 		err := tx.InsertSecurity(security) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return security, nil | ||||
| 	} | ||||
|  | ||||
| 	var securities []*models.Security | ||||
|  | ||||
| 	_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) | ||||
| 	securities, err := tx.FindMatchingSecurities(security) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -159,7 +85,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (* | ||||
| 	// First try to find a case insensitive match on the name or symbol | ||||
| 	upperName := strings.ToUpper(security.Name) | ||||
| 	upperSymbol := strings.ToUpper(security.Symbol) | ||||
| 	for _, s := range securities { | ||||
| 	for _, s := range *securities { | ||||
| 		if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) || | ||||
| 			(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) { | ||||
| 			return s, nil | ||||
| @@ -168,7 +94,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (* | ||||
| 	//		if strings.Contains(strings.ToUpper(security.Name), upperSearch) || | ||||
|  | ||||
| 	// Try to find a partial string match on the name or symbol | ||||
| 	for _, s := range securities { | ||||
| 	for _, s := range *securities { | ||||
| 		sUpperName := strings.ToUpper(s.Name) | ||||
| 		sUpperSymbol := strings.ToUpper(s.Symbol) | ||||
| 		if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) || | ||||
| @@ -178,12 +104,12 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (* | ||||
| 	} | ||||
|  | ||||
| 	// Give up and return the first security in the list | ||||
| 	if len(securities) > 0 { | ||||
| 		return securities[0], nil | ||||
| 	if len(*securities) > 0 { | ||||
| 		return (*securities)[0], nil | ||||
| 	} | ||||
|  | ||||
| 	// If there wasn't even one security in the list, make a new one | ||||
| 	err = InsertSecurity(tx, security) | ||||
| 	err = tx.InsertSecurity(security) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -216,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 		security.SecurityId = -1 | ||||
| 		security.UserId = user.UserId | ||||
|  | ||||
| 		err = InsertSecurity(context.Tx, &security) | ||||
| 		err = context.Tx.InsertSecurity(&security) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| @@ -228,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 			//Return all securities | ||||
| 			var sl models.SecurityList | ||||
|  | ||||
| 			securities, err := GetSecurities(context.Tx, user.UserId) | ||||
| 			securities, err := context.Tx.GetSecurities(user.UserId) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -249,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 				return PriceHandler(r, context, user, securityid) | ||||
| 			} | ||||
|  | ||||
| 			security, err := GetSecurity(context.Tx, securityid, user.UserId) | ||||
| 			security, err := context.Tx.GetSecurity(securityid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| @@ -283,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
|  | ||||
| 			return &security | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			security, err := GetSecurity(context.Tx, securityid, user.UserId) | ||||
| 			security, err := context.Tx.GetSecurity(securityid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteSecurity(context.Tx, security) | ||||
| 			if _, ok := err.(SecurityInUseError); ok { | ||||
| 			err = context.Tx.DeleteSecurity(security) | ||||
| 			if _, ok := err.(store.SecurityInUseError); ok { | ||||
| 				return NewError(7 /*In Use Error*/) | ||||
| 			} else if err != nil { | ||||
| 				log.Print(err) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| @@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
| @@ -26,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		securities, err := GetSecurities(tx, user.UserId) | ||||
| 		securities, err := tx.GetSecurities(user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int { | ||||
| 	date := luaCheckTime(L, 3) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	tx, ok := ctx.Value(dbContextKey).(*Tx) | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|   | ||||
| @@ -3,36 +3,37 @@ package handlers | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func GetSession(tx *Tx, r *http.Request) (*models.Session, error) { | ||||
| 	var s models.Session | ||||
|  | ||||
| func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) { | ||||
| 	cookie, err := r.Cookie("moneygo-session") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("moneygo-session cookie not set") | ||||
| 	} | ||||
| 	s.SessionSecret = cookie.Value | ||||
|  | ||||
| 	err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) | ||||
| 	s, err := tx.GetSession(cookie.Value) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if s.Expires.Before(time.Now()) { | ||||
| 		tx.Delete(&s) | ||||
| 		err := tx.DeleteSession(s) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Unexpected error when attempting to delete expired session: %s", err) | ||||
| 		} | ||||
| 		return nil, fmt.Errorf("Session has expired") | ||||
| 	} | ||||
| 	return &s, nil | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| func DeleteSessionIfExists(tx *Tx, r *http.Request) error { | ||||
| func DeleteSessionIfExists(tx store.Tx, r *http.Request) error { | ||||
| 	session, err := GetSession(tx, r) | ||||
| 	if err == nil { | ||||
| 		_, err := tx.Delete(session) | ||||
| 		err := tx.DeleteSession(session) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error { | ||||
| 	return n.session.Write(w) | ||||
| } | ||||
|  | ||||
| func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { | ||||
| func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { | ||||
| 	err := DeleteSessionIfExists(tx, r) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	s, err := models.NewSession(userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret) | ||||
| 	exists, err := tx.SessionExists(s.SessionSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if existing > 0 { | ||||
| 		return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing) | ||||
| 	if exists { | ||||
| 		return nil, fmt.Errorf("Session already exists with the generated session_secret") | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Insert(s) | ||||
| 	err = tx.InsertSession(s) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -79,22 +85,19 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		dbuser, err := GetUserByUsername(context.Tx, user.Username) | ||||
| 		// Hash password before checking username to help mitigate timing | ||||
| 		// attacks | ||||
| 		user.HashPassword() | ||||
|  | ||||
| 		dbuser, err := context.Tx.GetUserByUsername(user.Username) | ||||
| 		if err != nil { | ||||
| 			return NewError(2 /*Unauthorized Access*/) | ||||
| 		} | ||||
|  | ||||
| 		user.HashPassword() | ||||
| 		if user.PasswordHash != dbuser.PasswordHash { | ||||
| 			return NewError(2 /*Unauthorized Access*/) | ||||
| 		} | ||||
|  | ||||
| 		err = DeleteSessionIfExists(context.Tx, r) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
| 			return NewError(999 /*Internal Error*/) | ||||
| 		} | ||||
|  | ||||
| 		sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId) | ||||
| 		if err != nil { | ||||
| 			log.Print(err) | ||||
|   | ||||
| @@ -2,24 +2,18 @@ package handlers | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"math/big" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) { | ||||
| 	count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) | ||||
| 	return count == 1, err | ||||
| } | ||||
|  | ||||
| // Return a map of security ID's to big.Rat's containing the amount that | ||||
| // security is imbalanced by | ||||
| func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) { | ||||
| func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big.Rat, error) { | ||||
| 	sums := make(map[int64]big.Rat) | ||||
|  | ||||
| 	if !t.Valid() { | ||||
| @@ -31,7 +25,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			var err error | ||||
| 			var account *models.Account | ||||
| 			account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) | ||||
| 			account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| @@ -47,7 +41,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, | ||||
|  | ||||
| // Returns true if all securities contained in this transaction are balanced, | ||||
| // false otherwise | ||||
| func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { | ||||
| func TransactionBalanced(tx store.Tx, t *models.Transaction) (bool, error) { | ||||
| 	var zero big.Rat | ||||
|  | ||||
| 	sums, err := GetTransactionImbalances(tx, t) | ||||
| @@ -63,219 +57,6 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { | ||||
| 	return true, nil | ||||
| } | ||||
|  | ||||
| func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) { | ||||
| 	var t models.Transaction | ||||
|  | ||||
| 	err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &t, nil | ||||
| } | ||||
|  | ||||
| func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { | ||||
| 	var transactions []models.Transaction | ||||
|  | ||||
| 	_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	for i := range transactions { | ||||
| 		_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &transactions, nil | ||||
| } | ||||
|  | ||||
| func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error { | ||||
| 	for i := range accountids { | ||||
| 		account, err := GetAccount(tx, accountids[i], user.UserId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		account.AccountVersion++ | ||||
| 		count, err := tx.Update(account) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type AccountMissingError struct{} | ||||
|  | ||||
| func (ame AccountMissingError) Error() string { | ||||
| 	return "Account missing" | ||||
| } | ||||
|  | ||||
| func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
| 	for i := range t.Splits { | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if existing != 1 { | ||||
| 				return AccountMissingError{} | ||||
| 			} | ||||
| 			a_map[t.Splits[i].AccountId] = true | ||||
| 		} else if t.Splits[i].SecurityId == -1 { | ||||
| 			return AccountMissingError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	//increment versions for all accounts | ||||
| 	var a_ids []int64 | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	// ensure at least one of the splits is associated with an actual account | ||||
| 	if len(a_ids) < 1 { | ||||
| 		return AccountMissingError{} | ||||
| 	} | ||||
| 	err := incrementAccountVersions(tx, user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	t.UserId = user.UserId | ||||
| 	err = tx.Insert(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		t.Splits[i].SplitId = -1 | ||||
| 		err = tx.Insert(t.Splits[i]) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { | ||||
| 	var existing_splits []*models.Split | ||||
|  | ||||
| 	_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
|  | ||||
| 	// Make a map with any existing splits for this transaction | ||||
| 	s_map := make(map[int64]bool) | ||||
| 	for i := range existing_splits { | ||||
| 		s_map[existing_splits[i].SplitId] = true | ||||
| 	} | ||||
|  | ||||
| 	// Insert splits, updating any pre-existing ones | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		_, ok := s_map[t.Splits[i].SplitId] | ||||
| 		if ok { | ||||
| 			count, err := tx.Update(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if count > 1 { | ||||
| 				return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) | ||||
| 			} | ||||
| 			delete(s_map, t.Splits[i].SplitId) | ||||
| 		} else { | ||||
| 			t.Splits[i].SplitId = -1 | ||||
| 			err := tx.Insert(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			a_map[t.Splits[i].AccountId] = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Delete any remaining pre-existing splits | ||||
| 	for i := range existing_splits { | ||||
| 		_, ok := s_map[existing_splits[i].SplitId] | ||||
| 		if existing_splits[i].AccountId != -1 { | ||||
| 			a_map[existing_splits[i].AccountId] = true | ||||
| 		} | ||||
| 		if ok { | ||||
| 			_, err := tx.Delete(existing_splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Increment versions for all accounts with modified splits | ||||
| 	var a_ids []int64 | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	err = incrementAccountVersions(tx, user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Update(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count > 1 { | ||||
| 		return fmt.Errorf("Updated %d transactions (expected 1)", count) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error { | ||||
| 	var accountids []int64 | ||||
| 	_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Delete(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Deleted more than one transaction") | ||||
| 	} | ||||
|  | ||||
| 	err = incrementAccountVersions(tx, user, accountids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
| 	user, err := GetUserFromSession(context.Tx, r) | ||||
| 	if err != nil { | ||||
| @@ -296,7 +77,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
|  | ||||
| 		for i := range transaction.Splits { | ||||
| 			transaction.Splits[i].SplitId = -1 | ||||
| 			_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 			_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| @@ -310,9 +91,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
| 			return NewError(3 /*Invalid Request*/) | ||||
| 		} | ||||
|  | ||||
| 		err = InsertTransaction(context.Tx, &transaction, user) | ||||
| 		err = context.Tx.InsertTransaction(&transaction, user) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(AccountMissingError); ok { | ||||
| 			if _, ok := err.(store.AccountMissingError); ok { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} else { | ||||
| 				log.Print(err) | ||||
| @@ -325,7 +106,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
| 		if context.LastLevel() { | ||||
| 			//Return all Transactions | ||||
| 			var al models.TransactionList | ||||
| 			transactions, err := GetTransactions(context.Tx, user.UserId) | ||||
| 			transactions, err := context.Tx.GetTransactions(user.UserId) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -338,7 +119,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| 			transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) | ||||
| 			transaction, err := context.Tx.GetTransaction(transactionid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
| @@ -370,13 +151,13 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
| 			} | ||||
|  | ||||
| 			for i := range transaction.Splits { | ||||
| 				_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 				_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId) | ||||
| 				if err != nil { | ||||
| 					return NewError(3 /*Invalid Request*/) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateTransaction(context.Tx, &transaction, user) | ||||
| 			err = context.Tx.UpdateTransaction(&transaction, user) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -384,12 +165,12 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
|  | ||||
| 			return &transaction | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) | ||||
| 			transaction, err := context.Tx.GetTransaction(transactionid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				return NewError(3 /*Invalid Request*/) | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteTransaction(context.Tx, transaction, user) | ||||
| 			err = context.Tx.DeleteTransaction(transaction, user) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
| @@ -401,41 +182,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|  | ||||
| func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { | ||||
| 	var pageDifference, tmp big.Rat | ||||
| 	for i := range transactions { | ||||
| 		_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		// Sum up the amounts from the splits we're returning so we can return | ||||
| 		// an ending balance | ||||
| 		for j := range transactions[i].Splits { | ||||
| 			if transactions[i].Splits[j].AccountId == accountid { | ||||
| 				rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
| 				tmp.Add(&pageDifference, rat_amount) | ||||
| 				pageDifference.Set(&tmp) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return &pageDifference, nil | ||||
| } | ||||
|  | ||||
| func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { | ||||
| 	var splits []models.Split | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) { | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range splits { | ||||
| 	for _, s := range *splits { | ||||
| 		rat_amount, err := models.GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| @@ -447,132 +196,6 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| // Assumes accountid is valid and is owned by the current user | ||||
| func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { | ||||
| 	var splits []models.Split | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := models.GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { | ||||
| 	var splits []models.Split | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := models.GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { | ||||
| 	var transactions []models.Transaction | ||||
| 	var atl models.AccountTransactionsList | ||||
|  | ||||
| 	var sqlsort, balanceLimitOffset string | ||||
| 	var balanceLimitOffsetArg uint64 | ||||
| 	if sort == "date-asc" { | ||||
| 		sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" | ||||
| 		balanceLimitOffset = " LIMIT ?" | ||||
| 		balanceLimitOffsetArg = page * limit | ||||
| 	} else if sort == "date-desc" { | ||||
| 		numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" | ||||
| 		balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) | ||||
| 		balanceLimitOffsetArg = (page + 1) * limit | ||||
| 	} | ||||
|  | ||||
| 	var sqloffset string | ||||
| 	if page > 0 { | ||||
| 		sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) | ||||
| 	} | ||||
|  | ||||
| 	account, err := GetAccount(tx, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Account = account | ||||
|  | ||||
| 	sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset | ||||
| 	_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Transactions = &transactions | ||||
|  | ||||
| 	pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.TotalTransactions = count | ||||
|  | ||||
| 	security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if security == nil { | ||||
| 		return nil, errors.New("Security not found") | ||||
| 	} | ||||
|  | ||||
| 	// Sum all the splits for all transaction splits for this account that | ||||
| 	// occurred before the page we're returning | ||||
| 	var amounts []string | ||||
| 	sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" | ||||
| 	_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var tmp, balance big.Rat | ||||
| 	for _, amount := range amounts { | ||||
| 		rat_amount, err := models.GetBigAmount(amount) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
| 	atl.BeginningBalance = balance.FloatString(security.Precision) | ||||
| 	atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) | ||||
|  | ||||
| 	return &atl, nil | ||||
| } | ||||
|  | ||||
| // Return only those transactions which have at least one split pertaining to | ||||
| // an account | ||||
| func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { | ||||
| @@ -608,7 +231,7 @@ func AccountTransactionsHandler(context *Context, r *http.Request, user *models. | ||||
| 		sort = sortstring | ||||
| 	} | ||||
|  | ||||
| 	accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit) | ||||
| 	accountTransactions, err := context.Tx.GetAccountTransactions(user, accountid, sort, page, limit) | ||||
| 	if err != nil { | ||||
| 		log.Print(err) | ||||
| 		return NewError(999 /*Internal Error*/) | ||||
|   | ||||
| @@ -276,7 +276,7 @@ func TestGetTransactions(t *testing.T) { | ||||
| 			found := false | ||||
| 			for _, tran := range *tl.Transactions { | ||||
| 				if tran.TransactionId == curr.TransactionId { | ||||
| 					ensureTransactionsMatch(t, &curr, &tran, nil, true, true) | ||||
| 					ensureTransactionsMatch(t, &curr, tran, nil, true, true) | ||||
| 					if _, ok := foundIds[tran.TransactionId]; ok { | ||||
| 						continue | ||||
| 					} | ||||
| @@ -410,7 +410,7 @@ func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Ac | ||||
| 		} | ||||
| 		if atl.Transactions != nil { | ||||
| 			for _, tran := range *atl.Transactions { | ||||
| 				transactions = append(transactions, tran) | ||||
| 				transactions = append(transactions, *tran) | ||||
| 			} | ||||
| 			lastFetchCount = int64(len(*atl.Transactions)) | ||||
| 		} else { | ||||
|   | ||||
| @@ -2,8 +2,8 @@ package handlers | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| ) | ||||
| @@ -14,41 +14,21 @@ func (ueu UserExistsError) Error() string { | ||||
| 	return "User exists" | ||||
| } | ||||
|  | ||||
| func GetUser(tx *Tx, userid int64) (*models.User, error) { | ||||
| 	var u models.User | ||||
|  | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func GetUserByUsername(tx *Tx, username string) (*models.User, error) { | ||||
| 	var u models.User | ||||
|  | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func InsertUser(tx *Tx, u *models.User) error { | ||||
| func InsertUser(tx store.Tx, u *models.User) error { | ||||
| 	security_template := FindCurrencyTemplate(u.DefaultCurrency) | ||||
| 	if security_template == nil { | ||||
| 		return errors.New("Invalid ISO4217 Default Currency") | ||||
| 	} | ||||
|  | ||||
| 	existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) | ||||
| 	exists, err := tx.UsernameExists(u.Username) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if existing > 0 { | ||||
| 	if exists { | ||||
| 		return UserExistsError{} | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Insert(u) | ||||
| 	err = tx.InsertUser(u) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -58,33 +38,31 @@ func InsertUser(tx *Tx, u *models.User) error { | ||||
| 	security = *security_template | ||||
| 	security.UserId = u.UserId | ||||
|  | ||||
| 	err = InsertSecurity(tx, &security) | ||||
| 	err = tx.InsertSecurity(&security) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Update the user's DefaultCurrency to our new SecurityId | ||||
| 	u.DefaultCurrency = security.SecurityId | ||||
| 	count, err := tx.Update(u) | ||||
| 	err = tx.UpdateUser(u) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if count != 1 { | ||||
| 		return errors.New("Would have updated more than one user") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { | ||||
| func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) { | ||||
| 	s, err := GetSession(tx, r) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return GetUser(tx, s.UserId) | ||||
| 	return tx.GetUser(s.UserId) | ||||
| } | ||||
|  | ||||
| func UpdateUser(tx *Tx, u *models.User) error { | ||||
| 	security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) | ||||
| func UpdateUser(tx store.Tx, u *models.User) error { | ||||
| 	security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { | ||||
| @@ -93,49 +71,7 @@ func UpdateUser(tx *Tx, u *models.User) error { | ||||
| 		return errors.New("New DefaultCurrency security is not a currency") | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Update(u) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if count != 1 { | ||||
| 		return errors.New("Would have updated more than one user") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteUser(tx *Tx, u *models.User) error { | ||||
| 	count, err := tx.Delete(u) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("No user to delete") | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) | ||||
| 	err = tx.UpdateUser(u) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -204,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { | ||||
|  | ||||
| 			return user | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			err := DeleteUser(context.Tx, user) | ||||
| 			err := context.Tx.DeleteUser(user) | ||||
| 			if err != nil { | ||||
| 				log.Print(err) | ||||
| 				return NewError(999 /*Internal Error*/) | ||||
|   | ||||
| @@ -94,7 +94,7 @@ type Account struct { | ||||
| } | ||||
|  | ||||
| type AccountList struct { | ||||
| 	Accounts *[]Account `json:"accounts"` | ||||
| 	Accounts *[]*Account `json:"accounts"` | ||||
| } | ||||
|  | ||||
| func (a *Account) Write(w http.ResponseWriter) error { | ||||
|   | ||||
| @@ -28,7 +28,7 @@ func (r *Report) Read(json_str string) error { | ||||
| } | ||||
|  | ||||
| type ReportList struct { | ||||
| 	Reports *[]Report `json:"reports"` | ||||
| 	Reports *[]*Report `json:"reports"` | ||||
| } | ||||
|  | ||||
| func (rl *ReportList) Write(w http.ResponseWriter) error { | ||||
|   | ||||
| @@ -82,12 +82,12 @@ type Transaction struct { | ||||
| } | ||||
|  | ||||
| type TransactionList struct { | ||||
| 	Transactions *[]Transaction `json:"transactions"` | ||||
| 	Transactions *[]*Transaction `json:"transactions"` | ||||
| } | ||||
|  | ||||
| type AccountTransactionsList struct { | ||||
| 	Account           *Account | ||||
| 	Transactions      *[]Transaction | ||||
| 	Transactions      *[]*Transaction | ||||
| 	TotalTransactions int64 | ||||
| 	BeginningBalance  string | ||||
| 	EndingBalance     string | ||||
|   | ||||
							
								
								
									
										133
									
								
								internal/store/db/accounts.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								internal/store/db/accounts.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,133 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) GetAccount(accountid int64, userid int64) (*models.Account, error) { | ||||
| 	var account models.Account | ||||
|  | ||||
| 	err := tx.SelectOne(&account, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &account, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetAccounts(userid int64) (*[]*models.Account, error) { | ||||
| 	var accounts []*models.Account | ||||
|  | ||||
| 	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &accounts, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) { | ||||
| 	var accounts []*models.Account | ||||
|  | ||||
| 	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC", account.UserId, account.SecurityId, account.Type, account.Name, account.ParentAccountId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &accounts, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) insertUpdateAccount(account *models.Account, insert bool) error { | ||||
| 	found := make(map[int64]bool) | ||||
| 	if !insert { | ||||
| 		found[account.AccountId] = true | ||||
| 	} | ||||
| 	parentid := account.ParentAccountId | ||||
| 	depth := 0 | ||||
| 	for parentid != -1 { | ||||
| 		depth += 1 | ||||
| 		if depth > 100 { | ||||
| 			return store.TooMuchNestingError{} | ||||
| 		} | ||||
|  | ||||
| 		var a models.Account | ||||
| 		err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) | ||||
| 		if err != nil { | ||||
| 			return store.ParentAccountMissingError{} | ||||
| 		} | ||||
|  | ||||
| 		// Insertion by itself can never result in circular dependencies | ||||
| 		if insert { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		found[parentid] = true | ||||
| 		parentid = a.ParentAccountId | ||||
| 		if _, ok := found[parentid]; ok { | ||||
| 			return store.CircularAccountsError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if insert { | ||||
| 		err := tx.Insert(account) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		oldacct, err := tx.GetAccount(account.AccountId, account.UserId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		account.AccountVersion = oldacct.AccountVersion + 1 | ||||
|  | ||||
| 		count, err := tx.Update(account) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) InsertAccount(account *models.Account) error { | ||||
| 	return tx.insertUpdateAccount(account, true) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) UpdateAccount(account *models.Account) error { | ||||
| 	return tx.insertUpdateAccount(account, false) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteAccount(account *models.Account) error { | ||||
| 	if account.ParentAccountId != -1 { | ||||
| 		// Re-parent splits to this account's parent account if this account isn't a root account | ||||
| 		_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", account.ParentAccountId, account.AccountId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Delete splits if this account is a root account | ||||
| 		_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", account.AccountId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Re-parent child accounts to this account's parent account | ||||
| 	_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", account.ParentAccountId, account.AccountId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Delete(account) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Was going to delete more than one account") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"github.com/aclindsa/gorp" | ||||
| 	"github.com/aclindsa/moneygo/internal/config" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	_ "github.com/go-sql-driver/mysql" | ||||
| 	_ "github.com/lib/pq" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| @@ -18,7 +19,7 @@ import ( | ||||
| // implementation's string type specified by the same. | ||||
| const luaMaxLengthBuffer int = 4096 | ||||
| 
 | ||||
| func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { | ||||
| func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { | ||||
| 	var dialect gorp.Dialect | ||||
| 	if dbtype == config.SQLite { | ||||
| 		dialect = gorp.SqliteDialect{} | ||||
| @@ -38,11 +39,11 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { | ||||
| 	dbmap := &gorp.DbMap{Db: db, Dialect: dialect} | ||||
| 	dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") | ||||
| 	dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") | ||||
| 	dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") | ||||
| 	dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") | ||||
| 	dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") | ||||
| 	dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") | ||||
| 	dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") | ||||
| 	dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") | ||||
| 	dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") | ||||
| 	rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") | ||||
| 	rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) | ||||
| 
 | ||||
| @@ -54,9 +55,50 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { | ||||
| 	return dbmap, nil | ||||
| } | ||||
| 
 | ||||
| func GetDSN(dbtype config.DbType, dsn string) string { | ||||
| func getDSN(dbtype config.DbType, dsn string) string { | ||||
| 	if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") { | ||||
| 		log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!") | ||||
| 	} | ||||
| 	return dsn | ||||
| } | ||||
| 
 | ||||
| type DbStore struct { | ||||
| 	dbMap *gorp.DbMap | ||||
| } | ||||
| 
 | ||||
| func (db *DbStore) Empty() error { | ||||
| 	return db.dbMap.TruncateTables() | ||||
| } | ||||
| 
 | ||||
| func (db *DbStore) Begin() (store.Tx, error) { | ||||
| 	tx, err := db.dbMap.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Tx{db.dbMap.Dialect, tx}, nil | ||||
| } | ||||
| 
 | ||||
| func (db *DbStore) Close() error { | ||||
| 	err := db.dbMap.Db.Close() | ||||
| 	db.dbMap = nil | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { | ||||
| 	dsn = getDSN(dbtype, dsn) | ||||
| 	database, err := sql.Open(dbtype.String(), dsn) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		if err != nil { | ||||
| 			database.Close() | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	dbmap, err := getDbMap(database, dbtype) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &DbStore{dbmap}, nil | ||||
| } | ||||
							
								
								
									
										78
									
								
								internal/store/db/prices.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								internal/store/db/prices.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) PriceExists(price *models.Price) (bool, error) { | ||||
| 	var prices []*models.Price | ||||
| 	_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) | ||||
| 	return len(prices) > 0, err | ||||
| } | ||||
|  | ||||
| func (tx *Tx) InsertPrice(price *models.Price) error { | ||||
| 	return tx.Insert(price) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) { | ||||
| 	var price models.Price | ||||
| 	err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &price, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) { | ||||
| 	var prices []*models.Price | ||||
|  | ||||
| 	_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &prices, nil | ||||
| } | ||||
|  | ||||
| // Return the latest price for security in currency units before date | ||||
| func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	var price models.Price | ||||
| 	err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &price, nil | ||||
| } | ||||
|  | ||||
| // Return the earliest price for security in currency units after date | ||||
| func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	var price models.Price | ||||
| 	err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &price, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) UpdatePrice(price *models.Price) error { | ||||
| 	count, err := tx.Update(price) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to update 1 price, was going to update %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeletePrice(price *models.Price) error { | ||||
| 	count, err := tx.Delete(price) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										56
									
								
								internal/store/db/reports.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								internal/store/db/reports.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) GetReport(reportid int64, userid int64) (*models.Report, error) { | ||||
| 	var r models.Report | ||||
|  | ||||
| 	err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &r, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetReports(userid int64) (*[]*models.Report, error) { | ||||
| 	var reports []*models.Report | ||||
|  | ||||
| 	_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &reports, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) InsertReport(report *models.Report) error { | ||||
| 	err := tx.Insert(report) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) UpdateReport(report *models.Report) error { | ||||
| 	count, err := tx.Update(report) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to update 1 report, was going to update %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteReport(report *models.Report) error { | ||||
| 	count, err := tx.Delete(report) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to delete 1 report, was going to delete %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										88
									
								
								internal/store/db/securities.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								internal/store/db/securities.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { | ||||
| 	var s models.Security | ||||
|  | ||||
| 	err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) { | ||||
| 	var securities []*models.Security | ||||
|  | ||||
| 	_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &securities, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) { | ||||
| 	var securities []*models.Security | ||||
|  | ||||
| 	_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", security.UserId, security.Type, security.AlternateId, security.Precision) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &securities, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) InsertSecurity(s *models.Security) error { | ||||
| 	err := tx.Insert(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) UpdateSecurity(security *models.Security) error { | ||||
| 	count, err := tx.Update(security) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to update 1 security, was going to update %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteSecurity(s *models.Security) error { | ||||
| 	// First, ensure no accounts are using this security | ||||
| 	accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) | ||||
|  | ||||
| 	if accounts != 0 { | ||||
| 		return store.SecurityInUseError{"One or more accounts still use this security"} | ||||
| 	} | ||||
|  | ||||
| 	user, err := tx.GetUser(s.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if user.DefaultCurrency == s.SecurityId { | ||||
| 		return store.SecurityInUseError{"Cannot delete security which is user's default currency"} | ||||
| 	} | ||||
|  | ||||
| 	// Remove all prices involving this security (either of this security, or | ||||
| 	// using it as a currency) | ||||
| 	_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Delete(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										42
									
								
								internal/store/db/sessions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								internal/store/db/sessions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) InsertSession(session *models.Session) error { | ||||
| 	return tx.Insert(session) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetSession(secret string) (*models.Session, error) { | ||||
| 	var s models.Session | ||||
|  | ||||
| 	err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if s.Expires.Before(time.Now()) { | ||||
| 		tx.Delete(&s) | ||||
| 		return nil, fmt.Errorf("Session has expired") | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) SessionExists(secret string) (bool, error) { | ||||
| 	existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret) | ||||
| 	return existing != 0, err | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteSession(session *models.Session) error { | ||||
| 	count, err := tx.Delete(session) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										361
									
								
								internal/store/db/transactions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										361
									
								
								internal/store/db/transactions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,361 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"math/big" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error { | ||||
| 	for i := range accountids { | ||||
| 		account, err := tx.GetAccount(accountids[i], user.UserId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		account.AccountVersion++ | ||||
| 		count, err := tx.Update(account) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error { | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
| 	for i := range t.Splits { | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if existing != 1 { | ||||
| 				return store.AccountMissingError{} | ||||
| 			} | ||||
| 			a_map[t.Splits[i].AccountId] = true | ||||
| 		} else if t.Splits[i].SecurityId == -1 { | ||||
| 			return store.AccountMissingError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	//increment versions for all accounts | ||||
| 	var a_ids []int64 | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	// ensure at least one of the splits is associated with an actual account | ||||
| 	if len(a_ids) < 1 { | ||||
| 		return store.AccountMissingError{} | ||||
| 	} | ||||
| 	err := tx.incrementAccountVersions(user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	t.UserId = user.UserId | ||||
| 	err = tx.Insert(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		t.Splits[i].SplitId = -1 | ||||
| 		err = tx.Insert(t.Splits[i]) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) SplitExists(s *models.Split) (bool, error) { | ||||
| 	count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) | ||||
| 	return count == 1, err | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) { | ||||
| 	var t models.Transaction | ||||
|  | ||||
| 	err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &t, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) { | ||||
| 	var transactions []*models.Transaction | ||||
|  | ||||
| 	_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	for i := range transactions { | ||||
| 		_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &transactions, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error { | ||||
| 	var existing_splits []*models.Split | ||||
|  | ||||
| 	_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
|  | ||||
| 	// Make a map with any existing splits for this transaction | ||||
| 	s_map := make(map[int64]bool) | ||||
| 	for i := range existing_splits { | ||||
| 		s_map[existing_splits[i].SplitId] = true | ||||
| 	} | ||||
|  | ||||
| 	// Insert splits, updating any pre-existing ones | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		_, ok := s_map[t.Splits[i].SplitId] | ||||
| 		if ok { | ||||
| 			count, err := tx.Update(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if count > 1 { | ||||
| 				return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) | ||||
| 			} | ||||
| 			delete(s_map, t.Splits[i].SplitId) | ||||
| 		} else { | ||||
| 			t.Splits[i].SplitId = -1 | ||||
| 			err := tx.Insert(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			a_map[t.Splits[i].AccountId] = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Delete any remaining pre-existing splits | ||||
| 	for i := range existing_splits { | ||||
| 		_, ok := s_map[existing_splits[i].SplitId] | ||||
| 		if existing_splits[i].AccountId != -1 { | ||||
| 			a_map[existing_splits[i].AccountId] = true | ||||
| 		} | ||||
| 		if ok { | ||||
| 			_, err := tx.Delete(existing_splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Increment versions for all accounts with modified splits | ||||
| 	var a_ids []int64 | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	err = tx.incrementAccountVersions(user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Update(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count > 1 { | ||||
| 		return fmt.Errorf("Updated %d transactions (expected 1)", count) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error { | ||||
| 	var accountids []int64 | ||||
| 	_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.Delete(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Deleted more than one transaction") | ||||
| 	} | ||||
|  | ||||
| 	err = tx.incrementAccountVersions(user, accountids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) { | ||||
| 	var splits []*models.Split | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &splits, nil | ||||
| } | ||||
|  | ||||
| // Assumes accountid is valid and is owned by the current user | ||||
| func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) { | ||||
| 	var splits []*models.Split | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &splits, err | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) { | ||||
| 	var splits []*models.Split | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" | ||||
| 	_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &splits, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) { | ||||
| 	var pageDifference, tmp big.Rat | ||||
| 	for i := range transactions { | ||||
| 		_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		// Sum up the amounts from the splits we're returning so we can return | ||||
| 		// an ending balance | ||||
| 		for j := range transactions[i].Splits { | ||||
| 			if transactions[i].Splits[j].AccountId == accountid { | ||||
| 				rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
| 				tmp.Add(&pageDifference, rat_amount) | ||||
| 				pageDifference.Set(&tmp) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return &pageDifference, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { | ||||
| 	var transactions []*models.Transaction | ||||
| 	var atl models.AccountTransactionsList | ||||
|  | ||||
| 	var sqlsort, balanceLimitOffset string | ||||
| 	var balanceLimitOffsetArg uint64 | ||||
| 	if sort == "date-asc" { | ||||
| 		sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" | ||||
| 		balanceLimitOffset = " LIMIT ?" | ||||
| 		balanceLimitOffsetArg = page * limit | ||||
| 	} else if sort == "date-desc" { | ||||
| 		numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" | ||||
| 		balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) | ||||
| 		balanceLimitOffsetArg = (page + 1) * limit | ||||
| 	} | ||||
|  | ||||
| 	var sqloffset string | ||||
| 	if page > 0 { | ||||
| 		sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) | ||||
| 	} | ||||
|  | ||||
| 	account, err := tx.GetAccount(accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Account = account | ||||
|  | ||||
| 	sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset | ||||
| 	_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Transactions = &transactions | ||||
|  | ||||
| 	pageDifference, err := tx.transactionsBalanceDifference(accountid, transactions) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.TotalTransactions = count | ||||
|  | ||||
| 	security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if security == nil { | ||||
| 		return nil, errors.New("Security not found") | ||||
| 	} | ||||
|  | ||||
| 	// Sum all the splits for all transaction splits for this account that | ||||
| 	// occurred before the page we're returning | ||||
| 	var amounts []string | ||||
| 	sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" | ||||
| 	_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var tmp, balance big.Rat | ||||
| 	for _, amount := range amounts { | ||||
| 		rat_amount, err := models.GetBigAmount(amount) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
| 	atl.BeginningBalance = balance.FloatString(security.Precision) | ||||
| 	atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) | ||||
|  | ||||
| 	return &atl, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package handlers | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| @@ -41,7 +41,20 @@ func (tx *Tx) Insert(list ...interface{}) error { | ||||
| } | ||||
| 
 | ||||
| func (tx *Tx) Update(list ...interface{}) (int64, error) { | ||||
| 	return tx.Tx.Update(list...) | ||||
| 	count, err := tx.Tx.Update(list...) | ||||
| 	if count == 0 { | ||||
| 		switch tx.Dialect.(type) { | ||||
| 		case gorp.MySQLDialect: | ||||
| 			// Always return 1 for 0 if we're using MySQL because it returns | ||||
| 			// count=0 if the row data was unchanged, even if the row existed | ||||
| 
 | ||||
| 			// TODO Find another way to fix this without risking ignoring | ||||
| 			// errors | ||||
| 
 | ||||
| 			count = 1 | ||||
| 		} | ||||
| 	} | ||||
| 	return count, err | ||||
| } | ||||
| 
 | ||||
| func (tx *Tx) Delete(list ...interface{}) (int64, error) { | ||||
| @@ -55,11 +68,3 @@ func (tx *Tx) Commit() error { | ||||
| func (tx *Tx) Rollback() error { | ||||
| 	return tx.Tx.Rollback() | ||||
| } | ||||
| 
 | ||||
| func GetTx(db *gorp.DbMap) (*Tx, error) { | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Tx{db.Dialect, tx}, nil | ||||
| } | ||||
							
								
								
									
										86
									
								
								internal/store/db/users.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								internal/store/db/users.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) UsernameExists(username string) (bool, error) { | ||||
| 	existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username) | ||||
| 	return existing != 0, err | ||||
| } | ||||
|  | ||||
| func (tx *Tx) InsertUser(user *models.User) error { | ||||
| 	return tx.Insert(user) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetUser(userid int64) (*models.User, error) { | ||||
| 	var u models.User | ||||
|  | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetUserByUsername(username string) (*models.User, error) { | ||||
| 	var u models.User | ||||
|  | ||||
| 	err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) UpdateUser(user *models.User) error { | ||||
| 	count, err := tx.Update(user) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to update 1 user, was going to update %d", count) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteUser(user *models.User) error { | ||||
| 	count, err := tx.Delete(user) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count) | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										123
									
								
								internal/store/store.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								internal/store/store.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | ||||
| package store | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type UserStore interface { | ||||
| 	UsernameExists(username string) (bool, error) | ||||
| 	InsertUser(user *models.User) error | ||||
| 	GetUser(userid int64) (*models.User, error) | ||||
| 	GetUserByUsername(username string) (*models.User, error) | ||||
| 	UpdateUser(user *models.User) error | ||||
| 	DeleteUser(user *models.User) error | ||||
| } | ||||
|  | ||||
| type SessionStore interface { | ||||
| 	SessionExists(secret string) (bool, error) | ||||
| 	InsertSession(session *models.Session) error | ||||
| 	GetSession(secret string) (*models.Session, error) | ||||
| 	DeleteSession(session *models.Session) error | ||||
| } | ||||
|  | ||||
| type SecurityInUseError struct { | ||||
| 	Message string | ||||
| } | ||||
|  | ||||
| func (e SecurityInUseError) Error() string { | ||||
| 	return e.Message | ||||
| } | ||||
|  | ||||
| type SecurityStore interface { | ||||
| 	InsertSecurity(security *models.Security) error | ||||
| 	GetSecurity(securityid int64, userid int64) (*models.Security, error) | ||||
| 	GetSecurities(userid int64) (*[]*models.Security, error) | ||||
| 	FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) | ||||
| 	UpdateSecurity(security *models.Security) error | ||||
| 	DeleteSecurity(security *models.Security) error | ||||
| } | ||||
|  | ||||
| type PriceStore interface { | ||||
| 	PriceExists(price *models.Price) (bool, error) | ||||
| 	InsertPrice(price *models.Price) error | ||||
| 	GetPrice(priceid, securityid int64) (*models.Price, error) | ||||
| 	GetPrices(securityid int64) (*[]*models.Price, error) | ||||
| 	GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) | ||||
| 	GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) | ||||
| 	UpdatePrice(price *models.Price) error | ||||
| 	DeletePrice(price *models.Price) error | ||||
| } | ||||
|  | ||||
| type ParentAccountMissingError struct{} | ||||
|  | ||||
| func (pame ParentAccountMissingError) Error() string { | ||||
| 	return "Parent account missing" | ||||
| } | ||||
|  | ||||
| type TooMuchNestingError struct{} | ||||
|  | ||||
| func (tmne TooMuchNestingError) Error() string { | ||||
| 	return "Too much account nesting" | ||||
| } | ||||
|  | ||||
| type CircularAccountsError struct{} | ||||
|  | ||||
| func (cae CircularAccountsError) Error() string { | ||||
| 	return "Would result in circular account relationship" | ||||
| } | ||||
|  | ||||
| type AccountStore interface { | ||||
| 	InsertAccount(account *models.Account) error | ||||
| 	GetAccount(accountid int64, userid int64) (*models.Account, error) | ||||
| 	GetAccounts(userid int64) (*[]*models.Account, error) | ||||
| 	FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) | ||||
| 	UpdateAccount(account *models.Account) error | ||||
| 	DeleteAccount(account *models.Account) error | ||||
| } | ||||
|  | ||||
| type AccountMissingError struct{} | ||||
|  | ||||
| func (ame AccountMissingError) Error() string { | ||||
| 	return "Account missing" | ||||
| } | ||||
|  | ||||
| type TransactionStore interface { | ||||
| 	SplitExists(s *models.Split) (bool, error) | ||||
| 	InsertTransaction(t *models.Transaction, user *models.User) error | ||||
| 	GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) | ||||
| 	GetTransactions(userid int64) (*[]*models.Transaction, error) | ||||
| 	UpdateTransaction(t *models.Transaction, user *models.User) error | ||||
| 	DeleteTransaction(t *models.Transaction, user *models.User) error | ||||
| 	GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) | ||||
| 	GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) | ||||
| 	GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) | ||||
| 	GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) | ||||
| } | ||||
|  | ||||
| type ReportStore interface { | ||||
| 	InsertReport(report *models.Report) error | ||||
| 	GetReport(reportid int64, userid int64) (*models.Report, error) | ||||
| 	GetReports(userid int64) (*[]*models.Report, error) | ||||
| 	UpdateReport(report *models.Report) error | ||||
| 	DeleteReport(report *models.Report) error | ||||
| } | ||||
|  | ||||
| type Tx interface { | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
|  | ||||
| 	UserStore | ||||
| 	SessionStore | ||||
| 	SecurityStore | ||||
| 	PriceStore | ||||
| 	AccountStore | ||||
| 	TransactionStore | ||||
| 	ReportStore | ||||
| } | ||||
|  | ||||
| type Store interface { | ||||
| 	Empty() error | ||||
| 	Begin() (Tx, error) | ||||
| 	Close() error | ||||
| } | ||||
							
								
								
									
										15
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								main.go
									
									
									
									
									
								
							| @@ -3,11 +3,10 @@ package main | ||||
| //go:generate make | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"flag" | ||||
| 	"github.com/aclindsa/moneygo/internal/config" | ||||
| 	"github.com/aclindsa/moneygo/internal/db" | ||||
| 	"github.com/aclindsa/moneygo/internal/handlers" | ||||
| 	"github.com/aclindsa/moneygo/internal/store/db" | ||||
| 	"github.com/kabukky/httpscerts" | ||||
| 	"log" | ||||
| 	"net" | ||||
| @@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) { | ||||
| } | ||||
|  | ||||
| func main() { | ||||
| 	dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) | ||||
| 	database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer database.Close() | ||||
|  | ||||
| 	dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType) | ||||
| 	db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	// Get ServeMux for API and add our own handlers for files | ||||
| 	servemux := http.NewServeMux() | ||||
| 	servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap}) | ||||
| 	servemux.Handle("/v1/", &handlers.APIHandler{Store: db}) | ||||
| 	servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir)) | ||||
| 	servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user