1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-10-30 07:40:05 -04:00

Merge pull request #35 from aclindsa/store_split

Split DB activity into 'store'
This commit is contained in:
Aaron Lindsay 2017-12-09 19:36:33 -05:00 committed by GitHub
commit 9cdf4f3c29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1216 additions and 949 deletions

View File

@ -20,12 +20,11 @@ env:
- MONEYGO_TEST_DB=mysql - MONEYGO_TEST_DB=mysql
- MONEYGO_TEST_DB=postgres - MONEYGO_TEST_DB=postgres
# OSX builds take too long, so don't wait for all of them # OSX builds take too long, so don't wait for them
matrix: matrix:
fast_finish: true fast_finish: true
allow_failures: allow_failures:
- os: osx - os: osx
go: master
before_install: before_install:
# Fetch/build coverage reporting tools # Fetch/build coverage reporting tools

View File

@ -3,43 +3,22 @@ package handlers
import ( import (
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"net/http" "net/http"
) )
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) {
var a models.Account
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
if err != nil {
return nil, err
}
return &a, nil
}
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
var accounts []models.Account
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
if err != nil {
return nil, err
}
return &accounts, nil
}
// Get (and attempt to create if it doesn't exist). Matches on UserId, // Get (and attempt to create if it doesn't exist). Matches on UserId,
// SecurityId, Type, Name, and ParentAccountId // SecurityId, Type, Name, and ParentAccountId
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { func GetCreateAccount(tx store.Tx, a models.Account) (*models.Account, error) {
var accounts []models.Account
var account models.Account var account models.Account
// Try to find the top-level trading account accounts, err := tx.FindMatchingAccounts(&a)
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(accounts) == 1 { if len(*accounts) > 0 {
account = accounts[0] account = *(*accounts)[0]
} else { } else {
account.UserId = a.UserId account.UserId = a.UserId
account.SecurityId = a.SecurityId account.SecurityId = a.SecurityId
@ -47,7 +26,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
account.Name = a.Name account.Name = a.Name
account.ParentAccountId = a.ParentAccountId account.ParentAccountId = a.ParentAccountId
err = tx.Insert(&account) err = tx.InsertAccount(&account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,11 +36,11 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
// Get (and attempt to create if it doesn't exist) the security/currency // Get (and attempt to create if it doesn't exist) the security/currency
// trading account for the supplied security/currency // trading account for the supplied security/currency
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { func GetTradingAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) {
var tradingAccount models.Account var tradingAccount models.Account
var account models.Account var account models.Account
user, err := GetUser(tx, userid) user, err := tx.GetUser(userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -78,7 +57,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
return nil, err return nil, err
} }
security, err := GetSecurity(tx, securityid, userid) security, err := tx.GetSecurity(securityid, userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -99,7 +78,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
// Get (and attempt to create if it doesn't exist) the security/currency // Get (and attempt to create if it doesn't exist) the security/currency
// imbalance account for the supplied security/currency // imbalance account for the supplied security/currency
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { func GetImbalanceAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) {
var imbalanceAccount models.Account var imbalanceAccount models.Account
var account models.Account var account models.Account
xxxtemplate := FindSecurityTemplate("XXX", models.Currency) xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
@ -123,7 +102,7 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun
return nil, err return nil, err
} }
security, err := GetSecurity(tx, securityid, userid) security, err := tx.GetSecurity(securityid, userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -142,120 +121,6 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun
return a, nil return a, nil
} }
type ParentAccountMissingError struct{}
func (pame ParentAccountMissingError) Error() string {
return "Parent account missing"
}
type TooMuchNestingError struct{}
func (tmne TooMuchNestingError) Error() string {
return "Too much nesting"
}
type CircularAccountsError struct{}
func (cae CircularAccountsError) Error() string {
return "Would result in circular account relationship"
}
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
found := make(map[int64]bool)
if !insert {
found[a.AccountId] = true
}
parentid := a.ParentAccountId
depth := 0
for parentid != -1 {
depth += 1
if depth > 100 {
return TooMuchNestingError{}
}
var a models.Account
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
if err != nil {
return ParentAccountMissingError{}
}
// Insertion by itself can never result in circular dependencies
if insert {
break
}
found[parentid] = true
parentid = a.ParentAccountId
if _, ok := found[parentid]; ok {
return CircularAccountsError{}
}
}
if insert {
err := tx.Insert(a)
if err != nil {
return err
}
} else {
oldacct, err := GetAccount(tx, a.AccountId, a.UserId)
if err != nil {
return err
}
a.AccountVersion = oldacct.AccountVersion + 1
count, err := tx.Update(a)
if err != nil {
return err
}
if count != 1 {
return errors.New("Updated more than one account")
}
}
return nil
}
func InsertAccount(tx *Tx, a *models.Account) error {
return insertUpdateAccount(tx, a, true)
}
func UpdateAccount(tx *Tx, a *models.Account) error {
return insertUpdateAccount(tx, a, false)
}
func DeleteAccount(tx *Tx, a *models.Account) error {
if a.ParentAccountId != -1 {
// Re-parent splits to this account's parent account if this account isn't a root account
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
if err != nil {
return err
}
} else {
// Delete splits if this account is a root account
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
if err != nil {
return err
}
}
// Re-parent child accounts to this account's parent account
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
if err != nil {
return err
}
count, err := tx.Delete(a)
if err != nil {
return err
}
if count != 1 {
return errors.New("Was going to delete more than one account")
}
return nil
}
func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
user, err := GetUserFromSession(context.Tx, r) user, err := GetUserFromSession(context.Tx, r)
if err != nil { if err != nil {
@ -279,7 +144,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
account.UserId = user.UserId account.UserId = user.UserId
account.AccountVersion = 0 account.AccountVersion = 0
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -288,9 +153,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = InsertAccount(context.Tx, &account) err = context.Tx.InsertAccount(&account)
if err != nil { if err != nil {
if _, ok := err.(ParentAccountMissingError); ok { if _, ok := err.(store.ParentAccountMissingError); ok {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} else { } else {
log.Print(err) log.Print(err)
@ -303,7 +168,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
if context.LastLevel() { if context.LastLevel() {
//Return all Accounts //Return all Accounts
var al models.AccountList var al models.AccountList
accounts, err := GetAccounts(context.Tx, user.UserId) accounts, err := context.Tx.GetAccounts(user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -319,7 +184,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
if context.LastLevel() { if context.LastLevel() {
// Return Account with this Id // Return Account with this Id
account, err := GetAccount(context.Tx, accountid, user.UserId) account, err := context.Tx.GetAccount(accountid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -340,7 +205,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
} }
account.UserId = user.UserId account.UserId = user.UserId
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -353,11 +218,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = UpdateAccount(context.Tx, &account) err = context.Tx.UpdateAccount(&account)
if err != nil { if err != nil {
if _, ok := err.(ParentAccountMissingError); ok { if _, ok := err.(store.ParentAccountMissingError); ok {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} else if _, ok := err.(CircularAccountsError); ok { } else if _, ok := err.(store.CircularAccountsError); ok {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} else { } else {
log.Print(err) log.Print(err)
@ -367,12 +232,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
return &account return &account
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
account, err := GetAccount(context.Tx, accountid, user.UserId) account, err := context.Tx.GetAccount(accountid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = DeleteAccount(context.Tx, account) err = context.Tx.DeleteAccount(account)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -4,8 +4,8 @@ import (
"context" "context"
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"math/big"
"strings" "strings"
) )
@ -16,7 +16,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok { if !ok {
return nil, errors.New("Couldn't find tx in lua's Context") return nil, errors.New("Couldn't find tx in lua's Context")
} }
@ -28,14 +28,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }
accounts, err := GetAccounts(tx, user.UserId) accounts, err := tx.GetAccounts(user.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account_map = make(map[int64]*models.Account) account_map = make(map[int64]*models.Account)
for i := range *accounts { for i := range *accounts {
account_map[(*accounts)[i].AccountId] = &(*accounts)[i] account_map[(*accounts)[i].AccountId] = (*accounts)[i]
} }
ctx = context.WithValue(ctx, accountsContextKey, account_map) ctx = context.WithValue(ctx, accountsContextKey, account_map)
@ -150,7 +150,7 @@ func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1) a := luaCheckAccount(L, 1)
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok { if !ok {
panic("Couldn't find tx in lua's Context") panic("Couldn't find tx in lua's Context")
} }
@ -167,24 +167,29 @@ func luaAccountBalance(L *lua.LState) int {
panic("SecurityId not in lua security_map") panic("SecurityId not in lua security_map")
} }
date := luaWeakCheckTime(L, 2) date := luaWeakCheckTime(L, 2)
var b Balance var splits *[]*models.Split
var rat *big.Rat
if date != nil { if date != nil {
end := luaWeakCheckTime(L, 3) end := luaWeakCheckTime(L, 3)
if end != nil { if end != nil {
rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end)
} else { } else {
rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date)
} }
} else { } else {
rat, err = GetAccountBalance(tx, user, a.AccountId) splits, err = tx.GetAccountSplits(user, a.AccountId)
} }
if err != nil { if err != nil {
panic("Failed to GetAccountBalance:" + err.Error()) panic("Failed to fetch splits for account:" + err.Error())
} }
b.Amount = rat rat, err := BalanceFromSplits(splits)
b.Security = security if err != nil {
L.Push(BalanceToLua(L, &b)) panic("Failed to calculate balance for account:" + err.Error())
}
b := &Balance{
Amount: rat,
Security: security,
}
L.Push(BalanceToLua(L, b))
return 1 return 1
} }

View File

@ -2,12 +2,11 @@ package handlers_test
import ( import (
"bytes" "bytes"
"database/sql"
"encoding/json" "encoding/json"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/db"
"github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
@ -253,24 +252,15 @@ func RunTests(m *testing.M) int {
dsn = envDSN dsn = envDSN
} }
dsn = db.GetDSN(dbType, dsn) db, err := db.GetStore(dbType, dsn)
database, err := sql.Open(dbType.String(), dsn)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer database.Close() defer db.Close()
dbmap, err := db.GetDbMap(database, dbType) db.Empty() // clear the DB tables
if err != nil {
log.Fatal(err)
}
err = dbmap.TruncateTables() server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
if err != nil {
log.Fatal(err)
}
server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
defer server.Close() defer server.Close()
return m.Run() return m.Run()

View File

@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
} }
split.AccountId = acctId split.AccountId = acctId
exists, err := SplitAlreadyImported(context.Tx, split) exists, err := context.Tx.SplitExists(split)
if err != nil { if err != nil {
log.Print("Error checking if split was already imported:", err) log.Print("Error checking if split was already imported:", err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -446,7 +446,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
} }
} }
if !already_imported { if !already_imported {
err := InsertTransaction(context.Tx, &transaction, user) err := context.Tx.InsertTransaction(&transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) {
} }
for i, account := range *accounts.Accounts { for i, account := range *accounts.Accounts {
if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 { if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 {
income = &(*accounts.Accounts)[i] income = (*accounts.Accounts)[i]
} else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 { } else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 {
equity = &(*accounts.Accounts)[i] equity = (*accounts.Accounts)[i]
} else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 { } else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 {
liabilities = &(*accounts.Accounts)[i] liabilities = (*accounts.Accounts)[i]
} else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 { } else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 {
expenses = &(*accounts.Accounts)[i] expenses = (*accounts.Accounts)[i]
} }
} }
if income == nil { if income == nil {
@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) {
} }
for i, account := range *accounts.Accounts { for i, account := range *accounts.Accounts {
if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId { if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId {
salary = &(*accounts.Accounts)[i] salary = (*accounts.Accounts)[i]
} else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId { } else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId {
openingbalances = &(*accounts.Accounts)[i] openingbalances = (*accounts.Accounts)[i]
} else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId { } else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId {
creditcard = &(*accounts.Accounts)[i] creditcard = (*accounts.Accounts)[i]
} else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { } else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
groceries = &(*accounts.Accounts)[i] groceries = (*accounts.Accounts)[i]
} else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { } else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
cable = &(*accounts.Accounts)[i] cable = (*accounts.Accounts)[i]
} }
} }
if salary == nil { if salary == nil {

View File

@ -1,8 +1,9 @@
package handlers package handlers
import ( import (
"github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"net/http" "net/http"
"path" "path"
@ -16,7 +17,7 @@ type ResponseWriterWriter interface {
} }
type Context struct { type Context struct {
Tx *Tx Tx store.Tx
User *models.User User *models.User
remainingURL string // portion of URL path not yet reached in the hierarchy remainingURL string // portion of URL path not yet reached in the hierarchy
} }
@ -46,11 +47,11 @@ func (c *Context) LastLevel() bool {
type Handler func(*http.Request, *Context) ResponseWriterWriter type Handler func(*http.Request, *Context) ResponseWriterWriter
type APIHandler struct { type APIHandler struct {
DB *gorp.DbMap Store *db.DbStore
} }
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := GetTx(ah.DB) tx, err := ah.Store.Begin()
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -3,6 +3,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/aclindsa/ofxgo" "github.com/aclindsa/ofxgo"
"io" "io"
"log" "log"
@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
return dec.Decode(od) return dec.Decode(od)
} }
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
itl, err := ImportOFX(r) itl, err := ImportOFX(r)
if err != nil { if err != nil {
@ -38,7 +39,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
} }
// Return Account with this Id // Return Account with this Id
account, err := GetAccount(tx, accountid, user.UserId) account, err := tx.GetAccount(accountid, user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
@ -158,7 +159,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
split := new(models.Split) split := new(models.Split)
r := new(big.Rat) r := new(big.Rat)
r.Neg(&imbalance) r.Neg(&imbalance)
security, err := GetSecurity(tx, imbalanced_security, user.UserId) security, err := tx.GetSecurity(imbalanced_security, user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -186,7 +187,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
split.SecurityId = -1 split.SecurityId = -1
} }
exists, err := SplitAlreadyImported(tx, split) exists, err := tx.SplitExists(split)
if err != nil { if err != nil {
log.Print("Error checking if split was already imported:", err) log.Print("Error checking if split was already imported:", err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -201,7 +202,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
} }
for _, transaction := range transactions { for _, transaction := range transactions {
err := InsertTransaction(tx, &transaction, user) err := tx.InsertTransaction(&transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -217,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
account, err := GetAccount(context.Tx, accountid, user.UserId) account, err := context.Tx.GetAccount(accountid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }

View File

@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu
} }
for _, account := range *accounts.Accounts { for _, account := range *accounts.Accounts {
if account.Name == name && account.Type == tipe && account.SecurityId == securityid { if account.Name == name && account.Type == tipe && account.SecurityId == securityid {
return &account, nil return account, nil
} }
} }
return nil, fmt.Errorf("Unable to find account: \"%s\"", name) return nil, fmt.Errorf("Unable to find account: \"%s\"", name)

View File

@ -2,82 +2,41 @@ package handlers
import ( import (
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"net/http" "net/http"
"time" "time"
) )
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error {
if len(price.RemoteId) == 0 { if len(price.RemoteId) == 0 {
// Always create a new price if we can't match on the RemoteId // Always create a new price if we can't match on the RemoteId
err := tx.Insert(price) err := tx.InsertPrice(price)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
var prices []*models.Price exists, err := tx.PriceExists(price)
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
if err != nil { if err != nil {
return err return err
} }
if exists {
if len(prices) > 0 {
return nil // price already exists return nil // price already exists
} }
err = tx.Insert(price) err = tx.InsertPrice(price)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
if err != nil {
return nil, err
}
return &p, nil
}
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
var prices []*models.Price
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
if err != nil {
return nil, err
}
return &prices, nil
}
// Return the latest price for security in currency units before date
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil {
return nil, err
}
return &p, nil
}
// Return the earliest price for security in currency units after date
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil {
return nil, err
}
return &p, nil
}
// Return the price for security in currency closest to date // Return the price for security in currency closest to date
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { func GetClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
earliest, _ := GetEarliestPrice(tx, security, currency, date) earliest, _ := tx.GetEarliestPrice(security, currency, date)
latest, err := GetLatestPrice(tx, security, currency, date) latest, err := tx.GetLatestPrice(security, currency, date)
// Return early if either earliest or latest are invalid // Return early if either earliest or latest are invalid
if earliest == nil { if earliest == nil {
@ -96,7 +55,7 @@ func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Tim
} }
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
security, err := GetSecurity(context.Tx, securityid, user.UserId) security, err := context.Tx.GetSecurity(securityid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -111,12 +70,12 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
if price.SecurityId != security.SecurityId { if price.SecurityId != security.SecurityId {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) _, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = context.Tx.Insert(&price) err = context.Tx.InsertPrice(&price)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -128,7 +87,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
//Return all this security's prices //Return all this security's prices
var pl models.PriceList var pl models.PriceList
prices, err := GetPrices(context.Tx, security.SecurityId) prices, err := context.Tx.GetPrices(security.SecurityId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -143,7 +102,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
price, err := GetPrice(context.Tx, priceid, security.SecurityId) price, err := context.Tx.GetPrice(priceid, security.SecurityId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -160,30 +119,30 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
_, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) _, err = context.Tx.GetSecurity(price.SecurityId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) _, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
count, err := context.Tx.Update(&price) err = context.Tx.UpdatePrice(&price)
if err != nil || count != 1 { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
} }
return &price return &price
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
price, err := GetPrice(context.Tx, priceid, security.SecurityId) price, err := context.Tx.GetPrice(priceid, security.SecurityId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
count, err := context.Tx.Delete(price) err = context.Tx.DeletePrice(price)
if err != nil || count != 1 { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"log" "log"
"net/http" "net/http"
@ -24,57 +25,7 @@ const (
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { func runReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
var r models.Report
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
if err != nil {
return nil, err
}
return &r, nil
}
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
var reports []models.Report
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
if err != nil {
return nil, err
}
return &reports, nil
}
func InsertReport(tx *Tx, r *models.Report) error {
err := tx.Insert(r)
if err != nil {
return err
}
return nil
}
func UpdateReport(tx *Tx, r *models.Report) error {
count, err := tx.Update(r)
if err != nil {
return err
}
if count != 1 {
return errors.New("Updated more than one report")
}
return nil
}
func DeleteReport(tx *Tx, r *models.Report) error {
count, err := tx.Delete(r)
if err != nil {
return err
}
if count != 1 {
return errors.New("Deleted more than one report")
}
return nil
}
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
// Create a new LState without opening the default libs for security // Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true}) L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close() defer L.Close()
@ -138,8 +89,8 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
} }
} }
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { func ReportTabulationHandler(tx store.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
report, err := GetReport(tx, reportid, user.UserId) report, err := tx.GetReport(reportid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -174,7 +125,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = InsertReport(context.Tx, &report) err = context.Tx.InsertReport(&report)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -185,7 +136,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
if context.LastLevel() { if context.LastLevel() {
//Return all Reports //Return all Reports
var rl models.ReportList var rl models.ReportList
reports, err := GetReports(context.Tx, user.UserId) reports, err := context.Tx.GetReports(user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -203,7 +154,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
return ReportTabulationHandler(context.Tx, r, user, reportid) return ReportTabulationHandler(context.Tx, r, user, reportid)
} else { } else {
// Return Report with this Id // Return Report with this Id
report, err := GetReport(context.Tx, reportid, user.UserId) report, err := context.Tx.GetReport(reportid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -227,7 +178,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = UpdateReport(context.Tx, &report) err = context.Tx.UpdateReport(&report)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -235,12 +186,12 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
return &report return &report
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
report, err := GetReport(context.Tx, reportid, user.UserId) report, err := context.Tx.GetReport(reportid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = DeleteReport(context.Tx, report) err = context.Tx.DeleteReport(report)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -4,8 +4,8 @@ package handlers
import ( import (
"errors" "errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
@ -50,108 +50,34 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
return nil return nil
} }
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { func UpdateSecurity(tx store.Tx, s *models.Security) (err error) {
var s models.Security user, err := tx.GetUser(s.UserId)
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
if err != nil {
return nil, err
}
return &s, nil
}
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
var securities []*models.Security
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
if err != nil {
return nil, err
}
return &securities, nil
}
func InsertSecurity(tx *Tx, s *models.Security) error {
err := tx.Insert(s)
if err != nil {
return err
}
return nil
}
func UpdateSecurity(tx *Tx, s *models.Security) (err error) {
user, err := GetUser(tx, s.UserId)
if err != nil { if err != nil {
return return
} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency { } else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
return errors.New("Cannot change security which is user's default currency to be non-currency") return errors.New("Cannot change security which is user's default currency to be non-currency")
} }
count, err := tx.Update(s) err = tx.UpdateSecurity(s)
if err != nil { if err != nil {
return return
} }
if count > 1 {
return fmt.Errorf("Updated %d securities (expected 1)", count)
}
return nil return nil
} }
type SecurityInUseError struct { func ImportGetCreateSecurity(tx store.Tx, userid int64, security *models.Security) (*models.Security, error) {
message string
}
func (e SecurityInUseError) Error() string {
return e.message
}
func DeleteSecurity(tx *Tx, s *models.Security) error {
// First, ensure no accounts are using this security
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
if accounts != 0 {
return SecurityInUseError{"One or more accounts still use this security"}
}
user, err := GetUser(tx, s.UserId)
if err != nil {
return err
} else if user.DefaultCurrency == s.SecurityId {
return SecurityInUseError{"Cannot delete security which is user's default currency"}
}
// Remove all prices involving this security (either of this security, or
// using it as a currency)
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
if err != nil {
return err
}
count, err := tx.Delete(s)
if err != nil {
return err
}
if count != 1 {
return errors.New("Deleted more than one security")
}
return nil
}
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) {
security.UserId = userid security.UserId = userid
if len(security.AlternateId) == 0 { if len(security.AlternateId) == 0 {
// Always create a new local security if we can't match on the AlternateId // Always create a new local security if we can't match on the AlternateId
err := InsertSecurity(tx, security) err := tx.InsertSecurity(security)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return security, nil return security, nil
} }
var securities []*models.Security securities, err := tx.FindMatchingSecurities(security)
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -159,7 +85,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
// First try to find a case insensitive match on the name or symbol // First try to find a case insensitive match on the name or symbol
upperName := strings.ToUpper(security.Name) upperName := strings.ToUpper(security.Name)
upperSymbol := strings.ToUpper(security.Symbol) upperSymbol := strings.ToUpper(security.Symbol)
for _, s := range securities { for _, s := range *securities {
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) || if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) { (len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
return s, nil return s, nil
@ -168,7 +94,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) || // if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
// Try to find a partial string match on the name or symbol // Try to find a partial string match on the name or symbol
for _, s := range securities { for _, s := range *securities {
sUpperName := strings.ToUpper(s.Name) sUpperName := strings.ToUpper(s.Name)
sUpperSymbol := strings.ToUpper(s.Symbol) sUpperSymbol := strings.ToUpper(s.Symbol)
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) || if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
@ -178,12 +104,12 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
} }
// Give up and return the first security in the list // Give up and return the first security in the list
if len(securities) > 0 { if len(*securities) > 0 {
return securities[0], nil return (*securities)[0], nil
} }
// If there wasn't even one security in the list, make a new one // If there wasn't even one security in the list, make a new one
err = InsertSecurity(tx, security) err = tx.InsertSecurity(security)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -216,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
security.SecurityId = -1 security.SecurityId = -1
security.UserId = user.UserId security.UserId = user.UserId
err = InsertSecurity(context.Tx, &security) err = context.Tx.InsertSecurity(&security)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -228,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
//Return all securities //Return all securities
var sl models.SecurityList var sl models.SecurityList
securities, err := GetSecurities(context.Tx, user.UserId) securities, err := context.Tx.GetSecurities(user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -249,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
return PriceHandler(r, context, user, securityid) return PriceHandler(r, context, user, securityid)
} }
security, err := GetSecurity(context.Tx, securityid, user.UserId) security, err := context.Tx.GetSecurity(securityid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -283,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
return &security return &security
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
security, err := GetSecurity(context.Tx, securityid, user.UserId) security, err := context.Tx.GetSecurity(securityid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = DeleteSecurity(context.Tx, security) err = context.Tx.DeleteSecurity(security)
if _, ok := err.(SecurityInUseError); ok { if _, ok := err.(store.SecurityInUseError); ok {
return NewError(7 /*In Use Error*/) return NewError(7 /*In Use Error*/)
} else if err != nil { } else if err != nil {
log.Print(err) log.Print(err)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
) )
@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok { if !ok {
return nil, errors.New("Couldn't find tx in lua's Context") return nil, errors.New("Couldn't find tx in lua's Context")
} }
@ -26,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }
securities, err := GetSecurities(tx, user.UserId) securities, err := tx.GetSecurities(user.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
date := luaCheckTime(L, 3) date := luaCheckTime(L, 3)
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok { if !ok {
panic("Couldn't find tx in lua's Context") panic("Couldn't find tx in lua's Context")
} }

View File

@ -3,36 +3,37 @@ package handlers
import ( import (
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"net/http" "net/http"
"time" "time"
) )
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) { func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
var s models.Session
cookie, err := r.Cookie("moneygo-session") cookie, err := r.Cookie("moneygo-session")
if err != nil { if err != nil {
return nil, fmt.Errorf("moneygo-session cookie not set") return nil, fmt.Errorf("moneygo-session cookie not set")
} }
s.SessionSecret = cookie.Value
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) s, err := tx.GetSession(cookie.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.Expires.Before(time.Now()) { if s.Expires.Before(time.Now()) {
tx.Delete(&s) err := tx.DeleteSession(s)
if err != nil {
log.Printf("Unexpected error when attempting to delete expired session: %s", err)
}
return nil, fmt.Errorf("Session has expired") return nil, fmt.Errorf("Session has expired")
} }
return &s, nil return s, nil
} }
func DeleteSessionIfExists(tx *Tx, r *http.Request) error { func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
session, err := GetSession(tx, r) session, err := GetSession(tx, r)
if err == nil { if err == nil {
_, err := tx.Delete(session) err := tx.DeleteSession(session)
if err != nil { if err != nil {
return err return err
} }
@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
return n.session.Write(w) return n.session.Write(w)
} }
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
err := DeleteSessionIfExists(tx, r)
if err != nil {
return nil, err
}
s, err := models.NewSession(userid) s, err := models.NewSession(userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret) exists, err := tx.SessionExists(s.SessionSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if existing > 0 { if exists {
return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing) return nil, fmt.Errorf("Session already exists with the generated session_secret")
} }
err = tx.Insert(s) err = tx.InsertSession(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -79,22 +85,19 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
dbuser, err := GetUserByUsername(context.Tx, user.Username) // Hash password before checking username to help mitigate timing
// attacks
user.HashPassword()
dbuser, err := context.Tx.GetUserByUsername(user.Username)
if err != nil { if err != nil {
return NewError(2 /*Unauthorized Access*/) return NewError(2 /*Unauthorized Access*/)
} }
user.HashPassword()
if user.PasswordHash != dbuser.PasswordHash { if user.PasswordHash != dbuser.PasswordHash {
return NewError(2 /*Unauthorized Access*/) return NewError(2 /*Unauthorized Access*/)
} }
err = DeleteSessionIfExists(context.Tx, r)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
}
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId) sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)

View File

@ -2,24 +2,18 @@ package handlers
import ( import (
"errors" "errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"math/big" "math/big"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"time"
) )
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) {
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
return count == 1, err
}
// Return a map of security ID's to big.Rat's containing the amount that // Return a map of security ID's to big.Rat's containing the amount that
// security is imbalanced by // security is imbalanced by
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) { func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
sums := make(map[int64]big.Rat) sums := make(map[int64]big.Rat)
if !t.Valid() { if !t.Valid() {
@ -31,7 +25,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
if t.Splits[i].AccountId != -1 { if t.Splits[i].AccountId != -1 {
var err error var err error
var account *models.Account var account *models.Account
account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -47,7 +41,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
// Returns true if all securities contained in this transaction are balanced, // Returns true if all securities contained in this transaction are balanced,
// false otherwise // false otherwise
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { func TransactionBalanced(tx store.Tx, t *models.Transaction) (bool, error) {
var zero big.Rat var zero big.Rat
sums, err := GetTransactionImbalances(tx, t) sums, err := GetTransactionImbalances(tx, t)
@ -63,219 +57,6 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
return true, nil return true, nil
} }
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) {
var t models.Transaction
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
if err != nil {
return nil, err
}
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
if err != nil {
return nil, err
}
return &t, nil
}
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
var transactions []models.Transaction
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
if err != nil {
return nil, err
}
for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
if err != nil {
return nil, err
}
}
return &transactions, nil
}
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
for i := range accountids {
account, err := GetAccount(tx, accountids[i], user.UserId)
if err != nil {
return err
}
account.AccountVersion++
count, err := tx.Update(account)
if err != nil {
return err
}
if count != 1 {
return errors.New("Updated more than one account")
}
}
return nil
}
type AccountMissingError struct{}
func (ame AccountMissingError) Error() string {
return "Account missing"
}
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
for i := range t.Splits {
if t.Splits[i].AccountId != -1 {
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
if err != nil {
return err
}
if existing != 1 {
return AccountMissingError{}
}
a_map[t.Splits[i].AccountId] = true
} else if t.Splits[i].SecurityId == -1 {
return AccountMissingError{}
}
}
//increment versions for all accounts
var a_ids []int64
for id := range a_map {
a_ids = append(a_ids, id)
}
// ensure at least one of the splits is associated with an actual account
if len(a_ids) < 1 {
return AccountMissingError{}
}
err := incrementAccountVersions(tx, user, a_ids)
if err != nil {
return err
}
t.UserId = user.UserId
err = tx.Insert(t)
if err != nil {
return err
}
for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId
t.Splits[i].SplitId = -1
err = tx.Insert(t.Splits[i])
if err != nil {
return err
}
}
return nil
}
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
var existing_splits []*models.Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
if err != nil {
return err
}
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
// Make a map with any existing splits for this transaction
s_map := make(map[int64]bool)
for i := range existing_splits {
s_map[existing_splits[i].SplitId] = true
}
// Insert splits, updating any pre-existing ones
for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId
_, ok := s_map[t.Splits[i].SplitId]
if ok {
count, err := tx.Update(t.Splits[i])
if err != nil {
return err
}
if count > 1 {
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
}
delete(s_map, t.Splits[i].SplitId)
} else {
t.Splits[i].SplitId = -1
err := tx.Insert(t.Splits[i])
if err != nil {
return err
}
}
if t.Splits[i].AccountId != -1 {
a_map[t.Splits[i].AccountId] = true
}
}
// Delete any remaining pre-existing splits
for i := range existing_splits {
_, ok := s_map[existing_splits[i].SplitId]
if existing_splits[i].AccountId != -1 {
a_map[existing_splits[i].AccountId] = true
}
if ok {
_, err := tx.Delete(existing_splits[i])
if err != nil {
return err
}
}
}
// Increment versions for all accounts with modified splits
var a_ids []int64
for id := range a_map {
a_ids = append(a_ids, id)
}
err = incrementAccountVersions(tx, user, a_ids)
if err != nil {
return err
}
count, err := tx.Update(t)
if err != nil {
return err
}
if count > 1 {
return fmt.Errorf("Updated %d transactions (expected 1)", count)
}
return nil
}
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
var accountids []int64
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
if err != nil {
return err
}
count, err := tx.Delete(t)
if err != nil {
return err
}
if count != 1 {
return errors.New("Deleted more than one transaction")
}
err = incrementAccountVersions(tx, user, accountids)
if err != nil {
return err
}
return nil
}
func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter { func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter {
user, err := GetUserFromSession(context.Tx, r) user, err := GetUserFromSession(context.Tx, r)
if err != nil { if err != nil {
@ -296,7 +77,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
for i := range transaction.Splits { for i := range transaction.Splits {
transaction.Splits[i].SplitId = -1 transaction.Splits[i].SplitId = -1
_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) _, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -310,9 +91,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = InsertTransaction(context.Tx, &transaction, user) err = context.Tx.InsertTransaction(&transaction, user)
if err != nil { if err != nil {
if _, ok := err.(AccountMissingError); ok { if _, ok := err.(store.AccountMissingError); ok {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} else { } else {
log.Print(err) log.Print(err)
@ -325,7 +106,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
if context.LastLevel() { if context.LastLevel() {
//Return all Transactions //Return all Transactions
var al models.TransactionList var al models.TransactionList
transactions, err := GetTransactions(context.Tx, user.UserId) transactions, err := context.Tx.GetTransactions(user.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -338,7 +119,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) transaction, err := context.Tx.GetTransaction(transactionid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -370,13 +151,13 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
} }
for i := range transaction.Splits { for i := range transaction.Splits {
_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) _, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
} }
err = UpdateTransaction(context.Tx, &transaction, user) err = context.Tx.UpdateTransaction(&transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -384,12 +165,12 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return &transaction return &transaction
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) transaction, err := context.Tx.GetTransaction(transactionid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
err = DeleteTransaction(context.Tx, transaction, user) err = context.Tx.DeleteTransaction(transaction, user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -401,41 +182,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) {
var pageDifference, tmp big.Rat
for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
if err != nil {
return nil, err
}
// Sum up the amounts from the splits we're returning so we can return
// an ending balance
for j := range transactions[i].Splits {
if transactions[i].Splits[j].AccountId == accountid {
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount)
if err != nil {
return nil, err
}
tmp.Add(&pageDifference, rat_amount)
pageDifference.Set(&tmp)
}
}
}
return &pageDifference, nil
}
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
_, err := tx.Select(&splits, sql, accountid, user.UserId)
if err != nil {
return nil, err
}
var balance, tmp big.Rat var balance, tmp big.Rat
for _, s := range splits { for _, s := range *splits {
rat_amount, err := models.GetBigAmount(s.Amount) rat_amount, err := models.GetBigAmount(s.Amount)
if err != nil { if err != nil {
return nil, err return nil, err
@ -447,132 +196,6 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
return &balance, nil return &balance, nil
} }
// Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
if err != nil {
return nil, err
}
var balance, tmp big.Rat
for _, s := range splits {
rat_amount, err := models.GetBigAmount(s.Amount)
if err != nil {
return nil, err
}
tmp.Add(&balance, rat_amount)
balance.Set(&tmp)
}
return &balance, nil
}
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
if err != nil {
return nil, err
}
var balance, tmp big.Rat
for _, s := range splits {
rat_amount, err := models.GetBigAmount(s.Amount)
if err != nil {
return nil, err
}
tmp.Add(&balance, rat_amount)
balance.Set(&tmp)
}
return &balance, nil
}
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
var transactions []models.Transaction
var atl models.AccountTransactionsList
var sqlsort, balanceLimitOffset string
var balanceLimitOffsetArg uint64
if sort == "date-asc" {
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
balanceLimitOffset = " LIMIT ?"
balanceLimitOffsetArg = page * limit
} else if sort == "date-desc" {
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
if err != nil {
return nil, err
}
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
balanceLimitOffsetArg = (page + 1) * limit
}
var sqloffset string
if page > 0 {
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
}
account, err := GetAccount(tx, accountid, user.UserId)
if err != nil {
return nil, err
}
atl.Account = account
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
if err != nil {
return nil, err
}
atl.Transactions = &transactions
pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions)
if err != nil {
return nil, err
}
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
if err != nil {
return nil, err
}
atl.TotalTransactions = count
security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
if err != nil {
return nil, err
}
if security == nil {
return nil, errors.New("Security not found")
}
// Sum all the splits for all transaction splits for this account that
// occurred before the page we're returning
var amounts []string
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
if err != nil {
return nil, err
}
var tmp, balance big.Rat
for _, amount := range amounts {
rat_amount, err := models.GetBigAmount(amount)
if err != nil {
return nil, err
}
tmp.Add(&balance, rat_amount)
balance.Set(&tmp)
}
atl.BeginningBalance = balance.FloatString(security.Precision)
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
return &atl, nil
}
// Return only those transactions which have at least one split pertaining to // Return only those transactions which have at least one split pertaining to
// an account // an account
func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
@ -608,7 +231,7 @@ func AccountTransactionsHandler(context *Context, r *http.Request, user *models.
sort = sortstring sort = sortstring
} }
accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit) accountTransactions, err := context.Tx.GetAccountTransactions(user, accountid, sort, page, limit)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -276,7 +276,7 @@ func TestGetTransactions(t *testing.T) {
found := false found := false
for _, tran := range *tl.Transactions { for _, tran := range *tl.Transactions {
if tran.TransactionId == curr.TransactionId { if tran.TransactionId == curr.TransactionId {
ensureTransactionsMatch(t, &curr, &tran, nil, true, true) ensureTransactionsMatch(t, &curr, tran, nil, true, true)
if _, ok := foundIds[tran.TransactionId]; ok { if _, ok := foundIds[tran.TransactionId]; ok {
continue continue
} }
@ -410,7 +410,7 @@ func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Ac
} }
if atl.Transactions != nil { if atl.Transactions != nil {
for _, tran := range *atl.Transactions { for _, tran := range *atl.Transactions {
transactions = append(transactions, tran) transactions = append(transactions, *tran)
} }
lastFetchCount = int64(len(*atl.Transactions)) lastFetchCount = int64(len(*atl.Transactions))
} else { } else {

View File

@ -2,8 +2,8 @@ package handlers
import ( import (
"errors" "errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"net/http" "net/http"
) )
@ -14,41 +14,21 @@ func (ueu UserExistsError) Error() string {
return "User exists" return "User exists"
} }
func GetUser(tx *Tx, userid int64) (*models.User, error) { func InsertUser(tx store.Tx, u *models.User) error {
var u models.User
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil {
return nil, err
}
return &u, nil
}
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
var u models.User
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
if err != nil {
return nil, err
}
return &u, nil
}
func InsertUser(tx *Tx, u *models.User) error {
security_template := FindCurrencyTemplate(u.DefaultCurrency) security_template := FindCurrencyTemplate(u.DefaultCurrency)
if security_template == nil { if security_template == nil {
return errors.New("Invalid ISO4217 Default Currency") return errors.New("Invalid ISO4217 Default Currency")
} }
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) exists, err := tx.UsernameExists(u.Username)
if err != nil { if err != nil {
return err return err
} }
if existing > 0 { if exists {
return UserExistsError{} return UserExistsError{}
} }
err = tx.Insert(u) err = tx.InsertUser(u)
if err != nil { if err != nil {
return err return err
} }
@ -58,33 +38,31 @@ func InsertUser(tx *Tx, u *models.User) error {
security = *security_template security = *security_template
security.UserId = u.UserId security.UserId = u.UserId
err = InsertSecurity(tx, &security) err = tx.InsertSecurity(&security)
if err != nil { if err != nil {
return err return err
} }
// Update the user's DefaultCurrency to our new SecurityId // Update the user's DefaultCurrency to our new SecurityId
u.DefaultCurrency = security.SecurityId u.DefaultCurrency = security.SecurityId
count, err := tx.Update(u) err = tx.UpdateUser(u)
if err != nil { if err != nil {
return err return err
} else if count != 1 {
return errors.New("Would have updated more than one user")
} }
return nil return nil
} }
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) {
s, err := GetSession(tx, r) s, err := GetSession(tx, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return GetUser(tx, s.UserId) return tx.GetUser(s.UserId)
} }
func UpdateUser(tx *Tx, u *models.User) error { func UpdateUser(tx store.Tx, u *models.User) error {
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId)
if err != nil { if err != nil {
return err return err
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
@ -93,49 +71,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
return errors.New("New DefaultCurrency security is not a currency") return errors.New("New DefaultCurrency security is not a currency")
} }
count, err := tx.Update(u) err = tx.UpdateUser(u)
if err != nil {
return err
} else if count != 1 {
return errors.New("Would have updated more than one user")
}
return nil
}
func DeleteUser(tx *Tx, u *models.User) error {
count, err := tx.Delete(u)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("No user to delete")
}
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
if err != nil { if err != nil {
return err return err
} }
@ -204,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
return user return user
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
err := DeleteUser(context.Tx, user) err := context.Tx.DeleteUser(user)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -94,7 +94,7 @@ type Account struct {
} }
type AccountList struct { type AccountList struct {
Accounts *[]Account `json:"accounts"` Accounts *[]*Account `json:"accounts"`
} }
func (a *Account) Write(w http.ResponseWriter) error { func (a *Account) Write(w http.ResponseWriter) error {

View File

@ -28,7 +28,7 @@ func (r *Report) Read(json_str string) error {
} }
type ReportList struct { type ReportList struct {
Reports *[]Report `json:"reports"` Reports *[]*Report `json:"reports"`
} }
func (rl *ReportList) Write(w http.ResponseWriter) error { func (rl *ReportList) Write(w http.ResponseWriter) error {

View File

@ -82,12 +82,12 @@ type Transaction struct {
} }
type TransactionList struct { type TransactionList struct {
Transactions *[]Transaction `json:"transactions"` Transactions *[]*Transaction `json:"transactions"`
} }
type AccountTransactionsList struct { type AccountTransactionsList struct {
Account *Account Account *Account
Transactions *[]Transaction Transactions *[]*Transaction
TotalTransactions int64 TotalTransactions int64
BeginningBalance string BeginningBalance string
EndingBalance string EndingBalance string

View File

@ -0,0 +1,133 @@
package db
import (
"errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
)
func (tx *Tx) GetAccount(accountid int64, userid int64) (*models.Account, error) {
var account models.Account
err := tx.SelectOne(&account, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
if err != nil {
return nil, err
}
return &account, nil
}
func (tx *Tx) GetAccounts(userid int64) (*[]*models.Account, error) {
var accounts []*models.Account
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
if err != nil {
return nil, err
}
return &accounts, nil
}
func (tx *Tx) FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) {
var accounts []*models.Account
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC", account.UserId, account.SecurityId, account.Type, account.Name, account.ParentAccountId)
if err != nil {
return nil, err
}
return &accounts, nil
}
func (tx *Tx) insertUpdateAccount(account *models.Account, insert bool) error {
found := make(map[int64]bool)
if !insert {
found[account.AccountId] = true
}
parentid := account.ParentAccountId
depth := 0
for parentid != -1 {
depth += 1
if depth > 100 {
return store.TooMuchNestingError{}
}
var a models.Account
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
if err != nil {
return store.ParentAccountMissingError{}
}
// Insertion by itself can never result in circular dependencies
if insert {
break
}
found[parentid] = true
parentid = a.ParentAccountId
if _, ok := found[parentid]; ok {
return store.CircularAccountsError{}
}
}
if insert {
err := tx.Insert(account)
if err != nil {
return err
}
} else {
oldacct, err := tx.GetAccount(account.AccountId, account.UserId)
if err != nil {
return err
}
account.AccountVersion = oldacct.AccountVersion + 1
count, err := tx.Update(account)
if err != nil {
return err
}
if count != 1 {
return errors.New("Updated more than one account")
}
}
return nil
}
func (tx *Tx) InsertAccount(account *models.Account) error {
return tx.insertUpdateAccount(account, true)
}
func (tx *Tx) UpdateAccount(account *models.Account) error {
return tx.insertUpdateAccount(account, false)
}
func (tx *Tx) DeleteAccount(account *models.Account) error {
if account.ParentAccountId != -1 {
// Re-parent splits to this account's parent account if this account isn't a root account
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", account.ParentAccountId, account.AccountId)
if err != nil {
return err
}
} else {
// Delete splits if this account is a root account
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", account.AccountId)
if err != nil {
return err
}
}
// Re-parent child accounts to this account's parent account
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", account.ParentAccountId, account.AccountId)
if err != nil {
return err
}
count, err := tx.Delete(account)
if err != nil {
return err
}
if count != 1 {
return errors.New("Was going to delete more than one account")
}
return nil
}

View File

@ -6,6 +6,7 @@ import (
"github.com/aclindsa/gorp" "github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -18,7 +19,7 @@ import (
// implementation's string type specified by the same. // implementation's string type specified by the same.
const luaMaxLengthBuffer int = 4096 const luaMaxLengthBuffer int = 4096
func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
var dialect gorp.Dialect var dialect gorp.Dialect
if dbtype == config.SQLite { if dbtype == config.SQLite {
dialect = gorp.SqliteDialect{} dialect = gorp.SqliteDialect{}
@ -38,11 +39,11 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
dbmap := &gorp.DbMap{Db: db, Dialect: dialect} dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId")
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId")
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId")
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId")
dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId")
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId")
rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId")
rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer)
@ -54,9 +55,50 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
return dbmap, nil return dbmap, nil
} }
func GetDSN(dbtype config.DbType, dsn string) string { func getDSN(dbtype config.DbType, dsn string) string {
if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") { if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") {
log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!") log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!")
} }
return dsn return dsn
} }
type DbStore struct {
dbMap *gorp.DbMap
}
func (db *DbStore) Empty() error {
return db.dbMap.TruncateTables()
}
func (db *DbStore) Begin() (store.Tx, error) {
tx, err := db.dbMap.Begin()
if err != nil {
return nil, err
}
return &Tx{db.dbMap.Dialect, tx}, nil
}
func (db *DbStore) Close() error {
err := db.dbMap.Db.Close()
db.dbMap = nil
return err
}
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
dsn = getDSN(dbtype, dsn)
database, err := sql.Open(dbtype.String(), dsn)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
database.Close()
}
}()
dbmap, err := getDbMap(database, dbtype)
if err != nil {
return nil, err
}
return &DbStore{dbmap}, nil
}

View File

@ -0,0 +1,78 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"time"
)
func (tx *Tx) PriceExists(price *models.Price) (bool, error) {
var prices []*models.Price
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
return len(prices) > 0, err
}
func (tx *Tx) InsertPrice(price *models.Price) error {
return tx.Insert(price)
}
func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) {
var price models.Price
err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
if err != nil {
return nil, err
}
return &price, nil
}
func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) {
var prices []*models.Price
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
if err != nil {
return nil, err
}
return &prices, nil
}
// Return the latest price for security in currency units before date
func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
var price models.Price
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil {
return nil, err
}
return &price, nil
}
// Return the earliest price for security in currency units after date
func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
var price models.Price
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil {
return nil, err
}
return &price, nil
}
func (tx *Tx) UpdatePrice(price *models.Price) error {
count, err := tx.Update(price)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to update 1 price, was going to update %d", count)
}
return nil
}
func (tx *Tx) DeletePrice(price *models.Price) error {
count, err := tx.Delete(price)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count)
}
return nil
}

View File

@ -0,0 +1,56 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
)
func (tx *Tx) GetReport(reportid int64, userid int64) (*models.Report, error) {
var r models.Report
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
if err != nil {
return nil, err
}
return &r, nil
}
func (tx *Tx) GetReports(userid int64) (*[]*models.Report, error) {
var reports []*models.Report
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
if err != nil {
return nil, err
}
return &reports, nil
}
func (tx *Tx) InsertReport(report *models.Report) error {
err := tx.Insert(report)
if err != nil {
return err
}
return nil
}
func (tx *Tx) UpdateReport(report *models.Report) error {
count, err := tx.Update(report)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to update 1 report, was going to update %d", count)
}
return nil
}
func (tx *Tx) DeleteReport(report *models.Report) error {
count, err := tx.Delete(report)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to delete 1 report, was going to delete %d", count)
}
return nil
}

View File

@ -0,0 +1,88 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
)
func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) {
var s models.Security
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
if err != nil {
return nil, err
}
return &s, nil
}
func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) {
var securities []*models.Security
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
if err != nil {
return nil, err
}
return &securities, nil
}
func (tx *Tx) FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) {
var securities []*models.Security
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", security.UserId, security.Type, security.AlternateId, security.Precision)
if err != nil {
return nil, err
}
return &securities, nil
}
func (tx *Tx) InsertSecurity(s *models.Security) error {
err := tx.Insert(s)
if err != nil {
return err
}
return nil
}
func (tx *Tx) UpdateSecurity(security *models.Security) error {
count, err := tx.Update(security)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to update 1 security, was going to update %d", count)
}
return nil
}
func (tx *Tx) DeleteSecurity(s *models.Security) error {
// First, ensure no accounts are using this security
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
if accounts != 0 {
return store.SecurityInUseError{"One or more accounts still use this security"}
}
user, err := tx.GetUser(s.UserId)
if err != nil {
return err
} else if user.DefaultCurrency == s.SecurityId {
return store.SecurityInUseError{"Cannot delete security which is user's default currency"}
}
// Remove all prices involving this security (either of this security, or
// using it as a currency)
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
if err != nil {
return err
}
count, err := tx.Delete(s)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count)
}
return nil
}

View File

@ -0,0 +1,42 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"time"
)
func (tx *Tx) InsertSession(session *models.Session) error {
return tx.Insert(session)
}
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
var s models.Session
err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
if err != nil {
return nil, err
}
if s.Expires.Before(time.Now()) {
tx.Delete(&s)
return nil, fmt.Errorf("Session has expired")
}
return &s, nil
}
func (tx *Tx) SessionExists(secret string) (bool, error) {
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
return existing != 0, err
}
func (tx *Tx) DeleteSession(session *models.Session) error {
count, err := tx.Delete(session)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
}
return nil
}

View File

@ -0,0 +1,361 @@
package db
import (
"errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"math/big"
"time"
)
func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error {
for i := range accountids {
account, err := tx.GetAccount(accountids[i], user.UserId)
if err != nil {
return err
}
account.AccountVersion++
count, err := tx.Update(account)
if err != nil {
return err
}
if count != 1 {
return errors.New("Updated more than one account")
}
}
return nil
}
func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error {
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
for i := range t.Splits {
if t.Splits[i].AccountId != -1 {
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
if err != nil {
return err
}
if existing != 1 {
return store.AccountMissingError{}
}
a_map[t.Splits[i].AccountId] = true
} else if t.Splits[i].SecurityId == -1 {
return store.AccountMissingError{}
}
}
//increment versions for all accounts
var a_ids []int64
for id := range a_map {
a_ids = append(a_ids, id)
}
// ensure at least one of the splits is associated with an actual account
if len(a_ids) < 1 {
return store.AccountMissingError{}
}
err := tx.incrementAccountVersions(user, a_ids)
if err != nil {
return err
}
t.UserId = user.UserId
err = tx.Insert(t)
if err != nil {
return err
}
for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId
t.Splits[i].SplitId = -1
err = tx.Insert(t.Splits[i])
if err != nil {
return err
}
}
return nil
}
func (tx *Tx) SplitExists(s *models.Split) (bool, error) {
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
return count == 1, err
}
func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) {
var t models.Transaction
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
if err != nil {
return nil, err
}
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
if err != nil {
return nil, err
}
return &t, nil
}
func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) {
var transactions []*models.Transaction
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
if err != nil {
return nil, err
}
for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
if err != nil {
return nil, err
}
}
return &transactions, nil
}
func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error {
var existing_splits []*models.Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
if err != nil {
return err
}
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
// Make a map with any existing splits for this transaction
s_map := make(map[int64]bool)
for i := range existing_splits {
s_map[existing_splits[i].SplitId] = true
}
// Insert splits, updating any pre-existing ones
for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId
_, ok := s_map[t.Splits[i].SplitId]
if ok {
count, err := tx.Update(t.Splits[i])
if err != nil {
return err
}
if count > 1 {
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
}
delete(s_map, t.Splits[i].SplitId)
} else {
t.Splits[i].SplitId = -1
err := tx.Insert(t.Splits[i])
if err != nil {
return err
}
}
if t.Splits[i].AccountId != -1 {
a_map[t.Splits[i].AccountId] = true
}
}
// Delete any remaining pre-existing splits
for i := range existing_splits {
_, ok := s_map[existing_splits[i].SplitId]
if existing_splits[i].AccountId != -1 {
a_map[existing_splits[i].AccountId] = true
}
if ok {
_, err := tx.Delete(existing_splits[i])
if err != nil {
return err
}
}
}
// Increment versions for all accounts with modified splits
var a_ids []int64
for id := range a_map {
a_ids = append(a_ids, id)
}
err = tx.incrementAccountVersions(user, a_ids)
if err != nil {
return err
}
count, err := tx.Update(t)
if err != nil {
return err
}
if count > 1 {
return fmt.Errorf("Updated %d transactions (expected 1)", count)
}
return nil
}
func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error {
var accountids []int64
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
if err != nil {
return err
}
count, err := tx.Delete(t)
if err != nil {
return err
}
if count != 1 {
return errors.New("Deleted more than one transaction")
}
err = tx.incrementAccountVersions(user, accountids)
if err != nil {
return err
}
return nil
}
func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) {
var splits []*models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
_, err := tx.Select(&splits, sql, accountid, user.UserId)
if err != nil {
return nil, err
}
return &splits, nil
}
// Assumes accountid is valid and is owned by the current user
func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) {
var splits []*models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
if err != nil {
return nil, err
}
return &splits, err
}
func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) {
var splits []*models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
if err != nil {
return nil, err
}
return &splits, nil
}
func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) {
var pageDifference, tmp big.Rat
for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
if err != nil {
return nil, err
}
// Sum up the amounts from the splits we're returning so we can return
// an ending balance
for j := range transactions[i].Splits {
if transactions[i].Splits[j].AccountId == accountid {
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount)
if err != nil {
return nil, err
}
tmp.Add(&pageDifference, rat_amount)
pageDifference.Set(&tmp)
}
}
}
return &pageDifference, nil
}
func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
var transactions []*models.Transaction
var atl models.AccountTransactionsList
var sqlsort, balanceLimitOffset string
var balanceLimitOffsetArg uint64
if sort == "date-asc" {
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
balanceLimitOffset = " LIMIT ?"
balanceLimitOffsetArg = page * limit
} else if sort == "date-desc" {
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
if err != nil {
return nil, err
}
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
balanceLimitOffsetArg = (page + 1) * limit
}
var sqloffset string
if page > 0 {
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
}
account, err := tx.GetAccount(accountid, user.UserId)
if err != nil {
return nil, err
}
atl.Account = account
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
if err != nil {
return nil, err
}
atl.Transactions = &transactions
pageDifference, err := tx.transactionsBalanceDifference(accountid, transactions)
if err != nil {
return nil, err
}
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
if err != nil {
return nil, err
}
atl.TotalTransactions = count
security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId)
if err != nil {
return nil, err
}
if security == nil {
return nil, errors.New("Security not found")
}
// Sum all the splits for all transaction splits for this account that
// occurred before the page we're returning
var amounts []string
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
if err != nil {
return nil, err
}
var tmp, balance big.Rat
for _, amount := range amounts {
rat_amount, err := models.GetBigAmount(amount)
if err != nil {
return nil, err
}
tmp.Add(&balance, rat_amount)
balance.Set(&tmp)
}
atl.BeginningBalance = balance.FloatString(security.Precision)
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
return &atl, nil
}

View File

@ -1,4 +1,4 @@
package handlers package db
import ( import (
"database/sql" "database/sql"
@ -41,7 +41,20 @@ func (tx *Tx) Insert(list ...interface{}) error {
} }
func (tx *Tx) Update(list ...interface{}) (int64, error) { func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...) count, err := tx.Tx.Update(list...)
if count == 0 {
switch tx.Dialect.(type) {
case gorp.MySQLDialect:
// Always return 1 for 0 if we're using MySQL because it returns
// count=0 if the row data was unchanged, even if the row existed
// TODO Find another way to fix this without risking ignoring
// errors
count = 1
}
}
return count, err
} }
func (tx *Tx) Delete(list ...interface{}) (int64, error) { func (tx *Tx) Delete(list ...interface{}) (int64, error) {
@ -55,11 +68,3 @@ func (tx *Tx) Commit() error {
func (tx *Tx) Rollback() error { func (tx *Tx) Rollback() error {
return tx.Tx.Rollback() return tx.Tx.Rollback()
} }
func GetTx(db *gorp.DbMap) (*Tx, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}
return &Tx{db.Dialect, tx}, nil
}

View File

@ -0,0 +1,86 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
)
func (tx *Tx) UsernameExists(username string) (bool, error) {
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username)
return existing != 0, err
}
func (tx *Tx) InsertUser(user *models.User) error {
return tx.Insert(user)
}
func (tx *Tx) GetUser(userid int64) (*models.User, error) {
var u models.User
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil {
return nil, err
}
return &u, nil
}
func (tx *Tx) GetUserByUsername(username string) (*models.User, error) {
var u models.User
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
if err != nil {
return nil, err
}
return &u, nil
}
func (tx *Tx) UpdateUser(user *models.User) error {
count, err := tx.Update(user)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to update 1 user, was going to update %d", count)
}
return nil
}
func (tx *Tx) DeleteUser(user *models.User) error {
count, err := tx.Delete(user)
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
}
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId)
if err != nil {
return err
}
return nil
}

123
internal/store/store.go Normal file
View File

@ -0,0 +1,123 @@
package store
import (
"github.com/aclindsa/moneygo/internal/models"
"time"
)
type UserStore interface {
UsernameExists(username string) (bool, error)
InsertUser(user *models.User) error
GetUser(userid int64) (*models.User, error)
GetUserByUsername(username string) (*models.User, error)
UpdateUser(user *models.User) error
DeleteUser(user *models.User) error
}
type SessionStore interface {
SessionExists(secret string) (bool, error)
InsertSession(session *models.Session) error
GetSession(secret string) (*models.Session, error)
DeleteSession(session *models.Session) error
}
type SecurityInUseError struct {
Message string
}
func (e SecurityInUseError) Error() string {
return e.Message
}
type SecurityStore interface {
InsertSecurity(security *models.Security) error
GetSecurity(securityid int64, userid int64) (*models.Security, error)
GetSecurities(userid int64) (*[]*models.Security, error)
FindMatchingSecurities(security *models.Security) (*[]*models.Security, error)
UpdateSecurity(security *models.Security) error
DeleteSecurity(security *models.Security) error
}
type PriceStore interface {
PriceExists(price *models.Price) (bool, error)
InsertPrice(price *models.Price) error
GetPrice(priceid, securityid int64) (*models.Price, error)
GetPrices(securityid int64) (*[]*models.Price, error)
GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error)
GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error)
UpdatePrice(price *models.Price) error
DeletePrice(price *models.Price) error
}
type ParentAccountMissingError struct{}
func (pame ParentAccountMissingError) Error() string {
return "Parent account missing"
}
type TooMuchNestingError struct{}
func (tmne TooMuchNestingError) Error() string {
return "Too much account nesting"
}
type CircularAccountsError struct{}
func (cae CircularAccountsError) Error() string {
return "Would result in circular account relationship"
}
type AccountStore interface {
InsertAccount(account *models.Account) error
GetAccount(accountid int64, userid int64) (*models.Account, error)
GetAccounts(userid int64) (*[]*models.Account, error)
FindMatchingAccounts(account *models.Account) (*[]*models.Account, error)
UpdateAccount(account *models.Account) error
DeleteAccount(account *models.Account) error
}
type AccountMissingError struct{}
func (ame AccountMissingError) Error() string {
return "Account missing"
}
type TransactionStore interface {
SplitExists(s *models.Split) (bool, error)
InsertTransaction(t *models.Transaction, user *models.User) error
GetTransaction(transactionid int64, userid int64) (*models.Transaction, error)
GetTransactions(userid int64) (*[]*models.Transaction, error)
UpdateTransaction(t *models.Transaction, user *models.User) error
DeleteTransaction(t *models.Transaction, user *models.User) error
GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error)
GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error)
GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error)
GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error)
}
type ReportStore interface {
InsertReport(report *models.Report) error
GetReport(reportid int64, userid int64) (*models.Report, error)
GetReports(userid int64) (*[]*models.Report, error)
UpdateReport(report *models.Report) error
DeleteReport(report *models.Report) error
}
type Tx interface {
Commit() error
Rollback() error
UserStore
SessionStore
SecurityStore
PriceStore
AccountStore
TransactionStore
ReportStore
}
type Store interface {
Empty() error
Begin() (Tx, error)
Close() error
}

15
main.go
View File

@ -3,11 +3,10 @@ package main
//go:generate make //go:generate make
import ( import (
"database/sql"
"flag" "flag"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/db"
"github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/kabukky/httpscerts" "github.com/kabukky/httpscerts"
"log" "log"
"net" "net"
@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) {
} }
func main() { func main() {
dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn)
if err != nil {
log.Fatal(err)
}
defer database.Close()
dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer db.Close()
// Get ServeMux for API and add our own handlers for files // Get ServeMux for API and add our own handlers for files
servemux := http.NewServeMux() servemux := http.NewServeMux()
servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap}) servemux.Handle("/v1/", &handlers.APIHandler{Store: db})
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))