mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-10-30 07:40:05 -04:00
Merge pull request #35 from aclindsa/store_split
Split DB activity into 'store'
This commit is contained in:
commit
9cdf4f3c29
@ -20,12 +20,11 @@ env:
|
||||
- MONEYGO_TEST_DB=mysql
|
||||
- MONEYGO_TEST_DB=postgres
|
||||
|
||||
# OSX builds take too long, so don't wait for all of them
|
||||
# OSX builds take too long, so don't wait for them
|
||||
matrix:
|
||||
fast_finish: true
|
||||
allow_failures:
|
||||
- os: osx
|
||||
go: master
|
||||
|
||||
before_install:
|
||||
# Fetch/build coverage reporting tools
|
||||
|
@ -3,43 +3,22 @@ package handlers
|
||||
import (
|
||||
"errors"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) {
|
||||
var a models.Account
|
||||
|
||||
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
|
||||
var accounts []models.Account
|
||||
|
||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &accounts, nil
|
||||
}
|
||||
|
||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
||||
// SecurityId, Type, Name, and ParentAccountId
|
||||
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
||||
var accounts []models.Account
|
||||
func GetCreateAccount(tx store.Tx, a models.Account) (*models.Account, error) {
|
||||
var account models.Account
|
||||
|
||||
// Try to find the top-level trading account
|
||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
|
||||
accounts, err := tx.FindMatchingAccounts(&a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 1 {
|
||||
account = accounts[0]
|
||||
if len(*accounts) > 0 {
|
||||
account = *(*accounts)[0]
|
||||
} else {
|
||||
account.UserId = a.UserId
|
||||
account.SecurityId = a.SecurityId
|
||||
@ -47,7 +26,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
||||
account.Name = a.Name
|
||||
account.ParentAccountId = a.ParentAccountId
|
||||
|
||||
err = tx.Insert(&account)
|
||||
err = tx.InsertAccount(&account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -57,11 +36,11 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
||||
|
||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||
// trading account for the supplied security/currency
|
||||
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||
func GetTradingAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||
var tradingAccount models.Account
|
||||
var account models.Account
|
||||
|
||||
user, err := GetUser(tx, userid)
|
||||
user, err := tx.GetUser(userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -78,7 +57,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
security, err := GetSecurity(tx, securityid, userid)
|
||||
security, err := tx.GetSecurity(securityid, userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -99,7 +78,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
|
||||
|
||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||
// imbalance account for the supplied security/currency
|
||||
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||
func GetImbalanceAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||
var imbalanceAccount models.Account
|
||||
var account models.Account
|
||||
xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
|
||||
@ -123,7 +102,7 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun
|
||||
return nil, err
|
||||
}
|
||||
|
||||
security, err := GetSecurity(tx, securityid, userid)
|
||||
security, err := tx.GetSecurity(securityid, userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -142,120 +121,6 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun
|
||||
return a, nil
|
||||
}
|
||||
|
||||
type ParentAccountMissingError struct{}
|
||||
|
||||
func (pame ParentAccountMissingError) Error() string {
|
||||
return "Parent account missing"
|
||||
}
|
||||
|
||||
type TooMuchNestingError struct{}
|
||||
|
||||
func (tmne TooMuchNestingError) Error() string {
|
||||
return "Too much nesting"
|
||||
}
|
||||
|
||||
type CircularAccountsError struct{}
|
||||
|
||||
func (cae CircularAccountsError) Error() string {
|
||||
return "Would result in circular account relationship"
|
||||
}
|
||||
|
||||
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
|
||||
found := make(map[int64]bool)
|
||||
if !insert {
|
||||
found[a.AccountId] = true
|
||||
}
|
||||
parentid := a.ParentAccountId
|
||||
depth := 0
|
||||
for parentid != -1 {
|
||||
depth += 1
|
||||
if depth > 100 {
|
||||
return TooMuchNestingError{}
|
||||
}
|
||||
|
||||
var a models.Account
|
||||
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
|
||||
if err != nil {
|
||||
return ParentAccountMissingError{}
|
||||
}
|
||||
|
||||
// Insertion by itself can never result in circular dependencies
|
||||
if insert {
|
||||
break
|
||||
}
|
||||
|
||||
found[parentid] = true
|
||||
parentid = a.ParentAccountId
|
||||
if _, ok := found[parentid]; ok {
|
||||
return CircularAccountsError{}
|
||||
}
|
||||
}
|
||||
|
||||
if insert {
|
||||
err := tx.Insert(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
oldacct, err := GetAccount(tx, a.AccountId, a.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.AccountVersion = oldacct.AccountVersion + 1
|
||||
|
||||
count, err := tx.Update(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one account")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertAccount(tx *Tx, a *models.Account) error {
|
||||
return insertUpdateAccount(tx, a, true)
|
||||
}
|
||||
|
||||
func UpdateAccount(tx *Tx, a *models.Account) error {
|
||||
return insertUpdateAccount(tx, a, false)
|
||||
}
|
||||
|
||||
func DeleteAccount(tx *Tx, a *models.Account) error {
|
||||
if a.ParentAccountId != -1 {
|
||||
// Re-parent splits to this account's parent account if this account isn't a root account
|
||||
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Delete splits if this account is a root account
|
||||
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Re-parent child accounts to this account's parent account
|
||||
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Delete(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Was going to delete more than one account")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
user, err := GetUserFromSession(context.Tx, r)
|
||||
if err != nil {
|
||||
@ -279,7 +144,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
account.UserId = user.UserId
|
||||
account.AccountVersion = 0
|
||||
|
||||
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
|
||||
security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -288,9 +153,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = InsertAccount(context.Tx, &account)
|
||||
err = context.Tx.InsertAccount(&account)
|
||||
if err != nil {
|
||||
if _, ok := err.(ParentAccountMissingError); ok {
|
||||
if _, ok := err.(store.ParentAccountMissingError); ok {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
} else {
|
||||
log.Print(err)
|
||||
@ -303,7 +168,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
if context.LastLevel() {
|
||||
//Return all Accounts
|
||||
var al models.AccountList
|
||||
accounts, err := GetAccounts(context.Tx, user.UserId)
|
||||
accounts, err := context.Tx.GetAccounts(user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -319,7 +184,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
|
||||
if context.LastLevel() {
|
||||
// Return Account with this Id
|
||||
account, err := GetAccount(context.Tx, accountid, user.UserId)
|
||||
account, err := context.Tx.GetAccount(accountid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -340,7 +205,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
}
|
||||
account.UserId = user.UserId
|
||||
|
||||
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
|
||||
security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -353,11 +218,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = UpdateAccount(context.Tx, &account)
|
||||
err = context.Tx.UpdateAccount(&account)
|
||||
if err != nil {
|
||||
if _, ok := err.(ParentAccountMissingError); ok {
|
||||
if _, ok := err.(store.ParentAccountMissingError); ok {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
} else if _, ok := err.(CircularAccountsError); ok {
|
||||
} else if _, ok := err.(store.CircularAccountsError); ok {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
} else {
|
||||
log.Print(err)
|
||||
@ -367,12 +232,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
|
||||
return &account
|
||||
} else if r.Method == "DELETE" {
|
||||
account, err := GetAccount(context.Tx, accountid, user.UserId)
|
||||
account, err := context.Tx.GetAccount(accountid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = DeleteAccount(context.Tx, account)
|
||||
err = context.Tx.DeleteAccount(account)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"github.com/yuin/gopher-lua"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -16,7 +16,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||
}
|
||||
@ -28,14 +28,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
accounts, err := GetAccounts(tx, user.UserId)
|
||||
accounts, err := tx.GetAccounts(user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account_map = make(map[int64]*models.Account)
|
||||
for i := range *accounts {
|
||||
account_map[(*accounts)[i].AccountId] = &(*accounts)[i]
|
||||
account_map[(*accounts)[i].AccountId] = (*accounts)[i]
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, accountsContextKey, account_map)
|
||||
@ -150,7 +150,7 @@ func luaAccountBalance(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
|
||||
ctx := L.Context()
|
||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||
if !ok {
|
||||
panic("Couldn't find tx in lua's Context")
|
||||
}
|
||||
@ -167,24 +167,29 @@ func luaAccountBalance(L *lua.LState) int {
|
||||
panic("SecurityId not in lua security_map")
|
||||
}
|
||||
date := luaWeakCheckTime(L, 2)
|
||||
var b Balance
|
||||
var rat *big.Rat
|
||||
var splits *[]*models.Split
|
||||
if date != nil {
|
||||
end := luaWeakCheckTime(L, 3)
|
||||
if end != nil {
|
||||
rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end)
|
||||
splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end)
|
||||
} else {
|
||||
rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date)
|
||||
splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date)
|
||||
}
|
||||
} else {
|
||||
rat, err = GetAccountBalance(tx, user, a.AccountId)
|
||||
splits, err = tx.GetAccountSplits(user, a.AccountId)
|
||||
}
|
||||
if err != nil {
|
||||
panic("Failed to GetAccountBalance:" + err.Error())
|
||||
panic("Failed to fetch splits for account:" + err.Error())
|
||||
}
|
||||
b.Amount = rat
|
||||
b.Security = security
|
||||
L.Push(BalanceToLua(L, &b))
|
||||
rat, err := BalanceFromSplits(splits)
|
||||
if err != nil {
|
||||
panic("Failed to calculate balance for account:" + err.Error())
|
||||
}
|
||||
b := &Balance{
|
||||
Amount: rat,
|
||||
Security: security,
|
||||
}
|
||||
L.Push(BalanceToLua(L, b))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
@ -2,12 +2,11 @@ package handlers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"github.com/aclindsa/moneygo/internal/config"
|
||||
"github.com/aclindsa/moneygo/internal/db"
|
||||
"github.com/aclindsa/moneygo/internal/handlers"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store/db"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
@ -253,24 +252,15 @@ func RunTests(m *testing.M) int {
|
||||
dsn = envDSN
|
||||
}
|
||||
|
||||
dsn = db.GetDSN(dbType, dsn)
|
||||
database, err := sql.Open(dbType.String(), dsn)
|
||||
db, err := db.GetStore(dbType, dsn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer database.Close()
|
||||
defer db.Close()
|
||||
|
||||
dbmap, err := db.GetDbMap(database, dbType)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
db.Empty() // clear the DB tables
|
||||
|
||||
err = dbmap.TruncateTables()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
|
||||
server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
|
||||
defer server.Close()
|
||||
|
||||
return m.Run()
|
||||
|
@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
|
||||
}
|
||||
split.AccountId = acctId
|
||||
|
||||
exists, err := SplitAlreadyImported(context.Tx, split)
|
||||
exists, err := context.Tx.SplitExists(split)
|
||||
if err != nil {
|
||||
log.Print("Error checking if split was already imported:", err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -446,7 +446,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
|
||||
}
|
||||
}
|
||||
if !already_imported {
|
||||
err := InsertTransaction(context.Tx, &transaction, user)
|
||||
err := context.Tx.InsertTransaction(&transaction, user)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) {
|
||||
}
|
||||
for i, account := range *accounts.Accounts {
|
||||
if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 {
|
||||
income = &(*accounts.Accounts)[i]
|
||||
income = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 {
|
||||
equity = &(*accounts.Accounts)[i]
|
||||
equity = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 {
|
||||
liabilities = &(*accounts.Accounts)[i]
|
||||
liabilities = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 {
|
||||
expenses = &(*accounts.Accounts)[i]
|
||||
expenses = (*accounts.Accounts)[i]
|
||||
}
|
||||
}
|
||||
if income == nil {
|
||||
@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) {
|
||||
}
|
||||
for i, account := range *accounts.Accounts {
|
||||
if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId {
|
||||
salary = &(*accounts.Accounts)[i]
|
||||
salary = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId {
|
||||
openingbalances = &(*accounts.Accounts)[i]
|
||||
openingbalances = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId {
|
||||
creditcard = &(*accounts.Accounts)[i]
|
||||
creditcard = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
|
||||
groceries = &(*accounts.Accounts)[i]
|
||||
groceries = (*accounts.Accounts)[i]
|
||||
} else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
|
||||
cable = &(*accounts.Accounts)[i]
|
||||
cable = (*accounts.Accounts)[i]
|
||||
}
|
||||
}
|
||||
if salary == nil {
|
||||
|
@ -1,8 +1,9 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/aclindsa/gorp"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"github.com/aclindsa/moneygo/internal/store/db"
|
||||
"log"
|
||||
"net/http"
|
||||
"path"
|
||||
@ -16,7 +17,7 @@ type ResponseWriterWriter interface {
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
Tx *Tx
|
||||
Tx store.Tx
|
||||
User *models.User
|
||||
remainingURL string // portion of URL path not yet reached in the hierarchy
|
||||
}
|
||||
@ -46,11 +47,11 @@ func (c *Context) LastLevel() bool {
|
||||
type Handler func(*http.Request, *Context) ResponseWriterWriter
|
||||
|
||||
type APIHandler struct {
|
||||
DB *gorp.DbMap
|
||||
Store *db.DbStore
|
||||
}
|
||||
|
||||
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
||||
tx, err := GetTx(ah.DB)
|
||||
tx, err := ah.Store.Begin()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"github.com/aclindsa/ofxgo"
|
||||
"io"
|
||||
"log"
|
||||
@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
|
||||
return dec.Decode(od)
|
||||
}
|
||||
|
||||
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
||||
func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
||||
itl, err := ImportOFX(r)
|
||||
|
||||
if err != nil {
|
||||
@ -38,7 +39,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
||||
}
|
||||
|
||||
// Return Account with this Id
|
||||
account, err := GetAccount(tx, accountid, user.UserId)
|
||||
account, err := tx.GetAccount(accountid, user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
@ -158,7 +159,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
||||
split := new(models.Split)
|
||||
r := new(big.Rat)
|
||||
r.Neg(&imbalance)
|
||||
security, err := GetSecurity(tx, imbalanced_security, user.UserId)
|
||||
security, err := tx.GetSecurity(imbalanced_security, user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -186,7 +187,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
||||
split.SecurityId = -1
|
||||
}
|
||||
|
||||
exists, err := SplitAlreadyImported(tx, split)
|
||||
exists, err := tx.SplitExists(split)
|
||||
if err != nil {
|
||||
log.Print("Error checking if split was already imported:", err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -201,7 +202,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
||||
}
|
||||
|
||||
for _, transaction := range transactions {
|
||||
err := InsertTransaction(tx, &transaction, user)
|
||||
err := tx.InsertTransaction(&transaction, user)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -217,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
account, err := GetAccount(context.Tx, accountid, user.UserId)
|
||||
account, err := context.Tx.GetAccount(accountid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu
|
||||
}
|
||||
for _, account := range *accounts.Accounts {
|
||||
if account.Name == name && account.Type == tipe && account.SecurityId == securityid {
|
||||
return &account, nil
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Unable to find account: \"%s\"", name)
|
||||
|
@ -2,82 +2,41 @@ package handlers
|
||||
|
||||
import (
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
|
||||
func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error {
|
||||
if len(price.RemoteId) == 0 {
|
||||
// Always create a new price if we can't match on the RemoteId
|
||||
err := tx.Insert(price)
|
||||
err := tx.InsertPrice(price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var prices []*models.Price
|
||||
|
||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
||||
exists, err := tx.PriceExists(price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(prices) > 0 {
|
||||
if exists {
|
||||
return nil // price already exists
|
||||
}
|
||||
|
||||
err = tx.Insert(price)
|
||||
err = tx.InsertPrice(price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
|
||||
var p models.Price
|
||||
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
|
||||
var prices []*models.Price
|
||||
|
||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &prices, nil
|
||||
}
|
||||
|
||||
// Return the latest price for security in currency units before date
|
||||
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||
var p models.Price
|
||||
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Return the earliest price for security in currency units after date
|
||||
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||
var p models.Price
|
||||
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Return the price for security in currency closest to date
|
||||
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||
earliest, _ := GetEarliestPrice(tx, security, currency, date)
|
||||
latest, err := GetLatestPrice(tx, security, currency, date)
|
||||
func GetClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||
earliest, _ := tx.GetEarliestPrice(security, currency, date)
|
||||
latest, err := tx.GetLatestPrice(security, currency, date)
|
||||
|
||||
// Return early if either earliest or latest are invalid
|
||||
if earliest == nil {
|
||||
@ -96,7 +55,7 @@ func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Tim
|
||||
}
|
||||
|
||||
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
|
||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
||||
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -111,12 +70,12 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
||||
if price.SecurityId != security.SecurityId {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
|
||||
_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = context.Tx.Insert(&price)
|
||||
err = context.Tx.InsertPrice(&price)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -128,7 +87,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
||||
//Return all this security's prices
|
||||
var pl models.PriceList
|
||||
|
||||
prices, err := GetPrices(context.Tx, security.SecurityId)
|
||||
prices, err := context.Tx.GetPrices(security.SecurityId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -143,7 +102,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
price, err := GetPrice(context.Tx, priceid, security.SecurityId)
|
||||
price, err := context.Tx.GetPrice(priceid, security.SecurityId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -160,30 +119,30 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
_, err = GetSecurity(context.Tx, price.SecurityId, user.UserId)
|
||||
_, err = context.Tx.GetSecurity(price.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
|
||||
_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
count, err := context.Tx.Update(&price)
|
||||
if err != nil || count != 1 {
|
||||
err = context.Tx.UpdatePrice(&price)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
}
|
||||
|
||||
return &price
|
||||
} else if r.Method == "DELETE" {
|
||||
price, err := GetPrice(context.Tx, priceid, security.SecurityId)
|
||||
price, err := context.Tx.GetPrice(priceid, security.SecurityId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
count, err := context.Tx.Delete(price)
|
||||
if err != nil || count != 1 {
|
||||
err = context.Tx.DeletePrice(price)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"github.com/yuin/gopher-lua"
|
||||
"log"
|
||||
"net/http"
|
||||
@ -24,57 +25,7 @@ const (
|
||||
|
||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
||||
|
||||
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
|
||||
var r models.Report
|
||||
|
||||
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
|
||||
var reports []models.Report
|
||||
|
||||
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &reports, nil
|
||||
}
|
||||
|
||||
func InsertReport(tx *Tx, r *models.Report) error {
|
||||
err := tx.Insert(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateReport(tx *Tx, r *models.Report) error {
|
||||
count, err := tx.Update(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one report")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteReport(tx *Tx, r *models.Report) error {
|
||||
count, err := tx.Delete(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Deleted more than one report")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
|
||||
func runReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
|
||||
// Create a new LState without opening the default libs for security
|
||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||
defer L.Close()
|
||||
@ -138,8 +89,8 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
|
||||
}
|
||||
}
|
||||
|
||||
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
||||
report, err := GetReport(tx, reportid, user.UserId)
|
||||
func ReportTabulationHandler(tx store.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
||||
report, err := tx.GetReport(reportid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -174,7 +125,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = InsertReport(context.Tx, &report)
|
||||
err = context.Tx.InsertReport(&report)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -185,7 +136,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
if context.LastLevel() {
|
||||
//Return all Reports
|
||||
var rl models.ReportList
|
||||
reports, err := GetReports(context.Tx, user.UserId)
|
||||
reports, err := context.Tx.GetReports(user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -203,7 +154,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return ReportTabulationHandler(context.Tx, r, user, reportid)
|
||||
} else {
|
||||
// Return Report with this Id
|
||||
report, err := GetReport(context.Tx, reportid, user.UserId)
|
||||
report, err := context.Tx.GetReport(reportid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -227,7 +178,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = UpdateReport(context.Tx, &report)
|
||||
err = context.Tx.UpdateReport(&report)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -235,12 +186,12 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
|
||||
return &report
|
||||
} else if r.Method == "DELETE" {
|
||||
report, err := GetReport(context.Tx, reportid, user.UserId)
|
||||
report, err := context.Tx.GetReport(reportid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = DeleteReport(context.Tx, report)
|
||||
err = context.Tx.DeleteReport(report)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -4,8 +4,8 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -50,108 +50,34 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) {
|
||||
var s models.Security
|
||||
|
||||
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
|
||||
var securities []*models.Security
|
||||
|
||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &securities, nil
|
||||
}
|
||||
|
||||
func InsertSecurity(tx *Tx, s *models.Security) error {
|
||||
err := tx.Insert(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateSecurity(tx *Tx, s *models.Security) (err error) {
|
||||
user, err := GetUser(tx, s.UserId)
|
||||
func UpdateSecurity(tx store.Tx, s *models.Security) (err error) {
|
||||
user, err := tx.GetUser(s.UserId)
|
||||
if err != nil {
|
||||
return
|
||||
} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
|
||||
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
||||
}
|
||||
|
||||
count, err := tx.Update(s)
|
||||
err = tx.UpdateSecurity(s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if count > 1 {
|
||||
return fmt.Errorf("Updated %d securities (expected 1)", count)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type SecurityInUseError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e SecurityInUseError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func DeleteSecurity(tx *Tx, s *models.Security) error {
|
||||
// First, ensure no accounts are using this security
|
||||
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||
|
||||
if accounts != 0 {
|
||||
return SecurityInUseError{"One or more accounts still use this security"}
|
||||
}
|
||||
|
||||
user, err := GetUser(tx, s.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if user.DefaultCurrency == s.SecurityId {
|
||||
return SecurityInUseError{"Cannot delete security which is user's default currency"}
|
||||
}
|
||||
|
||||
// Remove all prices involving this security (either of this security, or
|
||||
// using it as a currency)
|
||||
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Delete(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Deleted more than one security")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) {
|
||||
func ImportGetCreateSecurity(tx store.Tx, userid int64, security *models.Security) (*models.Security, error) {
|
||||
security.UserId = userid
|
||||
if len(security.AlternateId) == 0 {
|
||||
// Always create a new local security if we can't match on the AlternateId
|
||||
err := InsertSecurity(tx, security)
|
||||
err := tx.InsertSecurity(security)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return security, nil
|
||||
}
|
||||
|
||||
var securities []*models.Security
|
||||
|
||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
|
||||
securities, err := tx.FindMatchingSecurities(security)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -159,7 +85,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
|
||||
// First try to find a case insensitive match on the name or symbol
|
||||
upperName := strings.ToUpper(security.Name)
|
||||
upperSymbol := strings.ToUpper(security.Symbol)
|
||||
for _, s := range securities {
|
||||
for _, s := range *securities {
|
||||
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
|
||||
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
|
||||
return s, nil
|
||||
@ -168,7 +94,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
|
||||
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
||||
|
||||
// Try to find a partial string match on the name or symbol
|
||||
for _, s := range securities {
|
||||
for _, s := range *securities {
|
||||
sUpperName := strings.ToUpper(s.Name)
|
||||
sUpperSymbol := strings.ToUpper(s.Symbol)
|
||||
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
|
||||
@ -178,12 +104,12 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
|
||||
}
|
||||
|
||||
// Give up and return the first security in the list
|
||||
if len(securities) > 0 {
|
||||
return securities[0], nil
|
||||
if len(*securities) > 0 {
|
||||
return (*securities)[0], nil
|
||||
}
|
||||
|
||||
// If there wasn't even one security in the list, make a new one
|
||||
err = InsertSecurity(tx, security)
|
||||
err = tx.InsertSecurity(security)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -216,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
security.SecurityId = -1
|
||||
security.UserId = user.UserId
|
||||
|
||||
err = InsertSecurity(context.Tx, &security)
|
||||
err = context.Tx.InsertSecurity(&security)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -228,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
//Return all securities
|
||||
var sl models.SecurityList
|
||||
|
||||
securities, err := GetSecurities(context.Tx, user.UserId)
|
||||
securities, err := context.Tx.GetSecurities(user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -249,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return PriceHandler(r, context, user, securityid)
|
||||
}
|
||||
|
||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
||||
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -283,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
|
||||
return &security
|
||||
} else if r.Method == "DELETE" {
|
||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
||||
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = DeleteSecurity(context.Tx, security)
|
||||
if _, ok := err.(SecurityInUseError); ok {
|
||||
err = context.Tx.DeleteSecurity(security)
|
||||
if _, ok := err.(store.SecurityInUseError); ok {
|
||||
return NewError(7 /*In Use Error*/)
|
||||
} else if err != nil {
|
||||
log.Print(err)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||
}
|
||||
@ -26,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
securities, err := GetSecurities(tx, user.UserId)
|
||||
securities, err := tx.GetSecurities(user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
|
||||
date := luaCheckTime(L, 3)
|
||||
|
||||
ctx := L.Context()
|
||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||
if !ok {
|
||||
panic("Couldn't find tx in lua's Context")
|
||||
}
|
||||
|
@ -3,36 +3,37 @@ package handlers
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) {
|
||||
var s models.Session
|
||||
|
||||
func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
|
||||
cookie, err := r.Cookie("moneygo-session")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("moneygo-session cookie not set")
|
||||
}
|
||||
s.SessionSecret = cookie.Value
|
||||
|
||||
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
||||
s, err := tx.GetSession(cookie.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.Expires.Before(time.Now()) {
|
||||
tx.Delete(&s)
|
||||
err := tx.DeleteSession(s)
|
||||
if err != nil {
|
||||
log.Printf("Unexpected error when attempting to delete expired session: %s", err)
|
||||
}
|
||||
return nil, fmt.Errorf("Session has expired")
|
||||
}
|
||||
return &s, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
|
||||
func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
|
||||
session, err := GetSession(tx, r)
|
||||
if err == nil {
|
||||
_, err := tx.Delete(session)
|
||||
err := tx.DeleteSession(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
|
||||
return n.session.Write(w)
|
||||
}
|
||||
|
||||
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
||||
func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
||||
err := DeleteSessionIfExists(tx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := models.NewSession(userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret)
|
||||
exists, err := tx.SessionExists(s.SessionSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existing > 0 {
|
||||
return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing)
|
||||
if exists {
|
||||
return nil, fmt.Errorf("Session already exists with the generated session_secret")
|
||||
}
|
||||
|
||||
err = tx.Insert(s)
|
||||
err = tx.InsertSession(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -79,22 +85,19 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
dbuser, err := GetUserByUsername(context.Tx, user.Username)
|
||||
// Hash password before checking username to help mitigate timing
|
||||
// attacks
|
||||
user.HashPassword()
|
||||
|
||||
dbuser, err := context.Tx.GetUserByUsername(user.Username)
|
||||
if err != nil {
|
||||
return NewError(2 /*Unauthorized Access*/)
|
||||
}
|
||||
|
||||
user.HashPassword()
|
||||
if user.PasswordHash != dbuser.PasswordHash {
|
||||
return NewError(2 /*Unauthorized Access*/)
|
||||
}
|
||||
|
||||
err = DeleteSessionIfExists(context.Tx, r)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
}
|
||||
|
||||
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
|
@ -2,24 +2,18 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) {
|
||||
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
||||
return count == 1, err
|
||||
}
|
||||
|
||||
// Return a map of security ID's to big.Rat's containing the amount that
|
||||
// security is imbalanced by
|
||||
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) {
|
||||
func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
|
||||
sums := make(map[int64]big.Rat)
|
||||
|
||||
if !t.Valid() {
|
||||
@ -31,7 +25,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
var err error
|
||||
var account *models.Account
|
||||
account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId)
|
||||
account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -47,7 +41,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
|
||||
|
||||
// Returns true if all securities contained in this transaction are balanced,
|
||||
// false otherwise
|
||||
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
|
||||
func TransactionBalanced(tx store.Tx, t *models.Transaction) (bool, error) {
|
||||
var zero big.Rat
|
||||
|
||||
sums, err := GetTransactionImbalances(tx, t)
|
||||
@ -63,219 +57,6 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) {
|
||||
var t models.Transaction
|
||||
|
||||
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
|
||||
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range transactions {
|
||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &transactions, nil
|
||||
}
|
||||
|
||||
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
|
||||
for i := range accountids {
|
||||
account, err := GetAccount(tx, accountids[i], user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
account.AccountVersion++
|
||||
count, err := tx.Update(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one account")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AccountMissingError struct{}
|
||||
|
||||
func (ame AccountMissingError) Error() string {
|
||||
return "Account missing"
|
||||
}
|
||||
|
||||
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
||||
// Map of any accounts with transaction splits being added
|
||||
a_map := make(map[int64]bool)
|
||||
for i := range t.Splits {
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != 1 {
|
||||
return AccountMissingError{}
|
||||
}
|
||||
a_map[t.Splits[i].AccountId] = true
|
||||
} else if t.Splits[i].SecurityId == -1 {
|
||||
return AccountMissingError{}
|
||||
}
|
||||
}
|
||||
|
||||
//increment versions for all accounts
|
||||
var a_ids []int64
|
||||
for id := range a_map {
|
||||
a_ids = append(a_ids, id)
|
||||
}
|
||||
// ensure at least one of the splits is associated with an actual account
|
||||
if len(a_ids) < 1 {
|
||||
return AccountMissingError{}
|
||||
}
|
||||
err := incrementAccountVersions(tx, user, a_ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.UserId = user.UserId
|
||||
err = tx.Insert(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range t.Splits {
|
||||
t.Splits[i].TransactionId = t.TransactionId
|
||||
t.Splits[i].SplitId = -1
|
||||
err = tx.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
||||
var existing_splits []*models.Split
|
||||
|
||||
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Map of any accounts with transaction splits being added
|
||||
a_map := make(map[int64]bool)
|
||||
|
||||
// Make a map with any existing splits for this transaction
|
||||
s_map := make(map[int64]bool)
|
||||
for i := range existing_splits {
|
||||
s_map[existing_splits[i].SplitId] = true
|
||||
}
|
||||
|
||||
// Insert splits, updating any pre-existing ones
|
||||
for i := range t.Splits {
|
||||
t.Splits[i].TransactionId = t.TransactionId
|
||||
_, ok := s_map[t.Splits[i].SplitId]
|
||||
if ok {
|
||||
count, err := tx.Update(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 1 {
|
||||
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
|
||||
}
|
||||
delete(s_map, t.Splits[i].SplitId)
|
||||
} else {
|
||||
t.Splits[i].SplitId = -1
|
||||
err := tx.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
a_map[t.Splits[i].AccountId] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Delete any remaining pre-existing splits
|
||||
for i := range existing_splits {
|
||||
_, ok := s_map[existing_splits[i].SplitId]
|
||||
if existing_splits[i].AccountId != -1 {
|
||||
a_map[existing_splits[i].AccountId] = true
|
||||
}
|
||||
if ok {
|
||||
_, err := tx.Delete(existing_splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment versions for all accounts with modified splits
|
||||
var a_ids []int64
|
||||
for id := range a_map {
|
||||
a_ids = append(a_ids, id)
|
||||
}
|
||||
err = incrementAccountVersions(tx, user, a_ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Update(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 1 {
|
||||
return fmt.Errorf("Updated %d transactions (expected 1)", count)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
||||
var accountids []int64
|
||||
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Delete(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Deleted more than one transaction")
|
||||
}
|
||||
|
||||
err = incrementAccountVersions(tx, user, accountids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
user, err := GetUserFromSession(context.Tx, r)
|
||||
if err != nil {
|
||||
@ -296,7 +77,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
|
||||
for i := range transaction.Splits {
|
||||
transaction.Splits[i].SplitId = -1
|
||||
_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId)
|
||||
_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -310,9 +91,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = InsertTransaction(context.Tx, &transaction, user)
|
||||
err = context.Tx.InsertTransaction(&transaction, user)
|
||||
if err != nil {
|
||||
if _, ok := err.(AccountMissingError); ok {
|
||||
if _, ok := err.(store.AccountMissingError); ok {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
} else {
|
||||
log.Print(err)
|
||||
@ -325,7 +106,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
if context.LastLevel() {
|
||||
//Return all Transactions
|
||||
var al models.TransactionList
|
||||
transactions, err := GetTransactions(context.Tx, user.UserId)
|
||||
transactions, err := context.Tx.GetTransactions(user.UserId)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -338,7 +119,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId)
|
||||
transaction, err := context.Tx.GetTransaction(transactionid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
@ -370,13 +151,13 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId)
|
||||
_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
}
|
||||
|
||||
err = UpdateTransaction(context.Tx, &transaction, user)
|
||||
err = context.Tx.UpdateTransaction(&transaction, user)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -384,12 +165,12 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
|
||||
return &transaction
|
||||
} else if r.Method == "DELETE" {
|
||||
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId)
|
||||
transaction, err := context.Tx.GetTransaction(transactionid, user.UserId)
|
||||
if err != nil {
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
err = DeleteTransaction(context.Tx, transaction, user)
|
||||
err = context.Tx.DeleteTransaction(transaction, user)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
@ -401,41 +182,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
||||
return NewError(3 /*Invalid Request*/)
|
||||
}
|
||||
|
||||
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
|
||||
var pageDifference, tmp big.Rat
|
||||
for i := range transactions {
|
||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sum up the amounts from the splits we're returning so we can return
|
||||
// an ending balance
|
||||
for j := range transactions[i].Splits {
|
||||
if transactions[i].Splits[j].AccountId == accountid {
|
||||
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&pageDifference, rat_amount)
|
||||
pageDifference.Set(&tmp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &pageDifference, nil
|
||||
}
|
||||
|
||||
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
|
||||
var splits []models.Split
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||
_, err := tx.Select(&splits, sql, accountid, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) {
|
||||
var balance, tmp big.Rat
|
||||
for _, s := range splits {
|
||||
for _, s := range *splits {
|
||||
rat_amount, err := models.GetBigAmount(s.Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -447,132 +196,6 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
// Assumes accountid is valid and is owned by the current user
|
||||
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||
var splits []models.Split
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var balance, tmp big.Rat
|
||||
for _, s := range splits {
|
||||
rat_amount, err := models.GetBigAmount(s.Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||
var splits []models.Split
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var balance, tmp big.Rat
|
||||
for _, s := range splits {
|
||||
rat_amount, err := models.GetBigAmount(s.Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
|
||||
var transactions []models.Transaction
|
||||
var atl models.AccountTransactionsList
|
||||
|
||||
var sqlsort, balanceLimitOffset string
|
||||
var balanceLimitOffsetArg uint64
|
||||
if sort == "date-asc" {
|
||||
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
|
||||
balanceLimitOffset = " LIMIT ?"
|
||||
balanceLimitOffsetArg = page * limit
|
||||
} else if sort == "date-desc" {
|
||||
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
|
||||
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
|
||||
balanceLimitOffsetArg = (page + 1) * limit
|
||||
}
|
||||
|
||||
var sqloffset string
|
||||
if page > 0 {
|
||||
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
||||
}
|
||||
|
||||
account, err := GetAccount(tx, accountid, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atl.Account = account
|
||||
|
||||
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
||||
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atl.Transactions = &transactions
|
||||
|
||||
pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atl.TotalTransactions = count
|
||||
|
||||
security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if security == nil {
|
||||
return nil, errors.New("Security not found")
|
||||
}
|
||||
|
||||
// Sum all the splits for all transaction splits for this account that
|
||||
// occurred before the page we're returning
|
||||
var amounts []string
|
||||
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
|
||||
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tmp, balance big.Rat
|
||||
for _, amount := range amounts {
|
||||
rat_amount, err := models.GetBigAmount(amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
atl.BeginningBalance = balance.FloatString(security.Precision)
|
||||
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
||||
|
||||
return &atl, nil
|
||||
}
|
||||
|
||||
// Return only those transactions which have at least one split pertaining to
|
||||
// an account
|
||||
func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
||||
@ -608,7 +231,7 @@ func AccountTransactionsHandler(context *Context, r *http.Request, user *models.
|
||||
sort = sortstring
|
||||
}
|
||||
|
||||
accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit)
|
||||
accountTransactions, err := context.Tx.GetAccountTransactions(user, accountid, sort, page, limit)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -276,7 +276,7 @@ func TestGetTransactions(t *testing.T) {
|
||||
found := false
|
||||
for _, tran := range *tl.Transactions {
|
||||
if tran.TransactionId == curr.TransactionId {
|
||||
ensureTransactionsMatch(t, &curr, &tran, nil, true, true)
|
||||
ensureTransactionsMatch(t, &curr, tran, nil, true, true)
|
||||
if _, ok := foundIds[tran.TransactionId]; ok {
|
||||
continue
|
||||
}
|
||||
@ -410,7 +410,7 @@ func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Ac
|
||||
}
|
||||
if atl.Transactions != nil {
|
||||
for _, tran := range *atl.Transactions {
|
||||
transactions = append(transactions, tran)
|
||||
transactions = append(transactions, *tran)
|
||||
}
|
||||
lastFetchCount = int64(len(*atl.Transactions))
|
||||
} else {
|
||||
|
@ -2,8 +2,8 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
@ -14,41 +14,21 @@ func (ueu UserExistsError) Error() string {
|
||||
return "User exists"
|
||||
}
|
||||
|
||||
func GetUser(tx *Tx, userid int64) (*models.User, error) {
|
||||
var u models.User
|
||||
|
||||
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
|
||||
var u models.User
|
||||
|
||||
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func InsertUser(tx *Tx, u *models.User) error {
|
||||
func InsertUser(tx store.Tx, u *models.User) error {
|
||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||
if security_template == nil {
|
||||
return errors.New("Invalid ISO4217 Default Currency")
|
||||
}
|
||||
|
||||
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username)
|
||||
exists, err := tx.UsernameExists(u.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing > 0 {
|
||||
if exists {
|
||||
return UserExistsError{}
|
||||
}
|
||||
|
||||
err = tx.Insert(u)
|
||||
err = tx.InsertUser(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -58,33 +38,31 @@ func InsertUser(tx *Tx, u *models.User) error {
|
||||
security = *security_template
|
||||
security.UserId = u.UserId
|
||||
|
||||
err = InsertSecurity(tx, &security)
|
||||
err = tx.InsertSecurity(&security)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the user's DefaultCurrency to our new SecurityId
|
||||
u.DefaultCurrency = security.SecurityId
|
||||
count, err := tx.Update(u)
|
||||
err = tx.UpdateUser(u)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count != 1 {
|
||||
return errors.New("Would have updated more than one user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
|
||||
func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) {
|
||||
s, err := GetSession(tx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return GetUser(tx, s.UserId)
|
||||
return tx.GetUser(s.UserId)
|
||||
}
|
||||
|
||||
func UpdateUser(tx *Tx, u *models.User) error {
|
||||
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
||||
func UpdateUser(tx store.Tx, u *models.User) error {
|
||||
security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
||||
@ -93,49 +71,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
|
||||
return errors.New("New DefaultCurrency security is not a currency")
|
||||
}
|
||||
|
||||
count, err := tx.Update(u)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count != 1 {
|
||||
return errors.New("Would have updated more than one user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteUser(tx *Tx, u *models.User) error {
|
||||
count, err := tx.Delete(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("No user to delete")
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
|
||||
err = tx.UpdateUser(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -204,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||
|
||||
return user
|
||||
} else if r.Method == "DELETE" {
|
||||
err := DeleteUser(context.Tx, user)
|
||||
err := context.Tx.DeleteUser(user)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -94,7 +94,7 @@ type Account struct {
|
||||
}
|
||||
|
||||
type AccountList struct {
|
||||
Accounts *[]Account `json:"accounts"`
|
||||
Accounts *[]*Account `json:"accounts"`
|
||||
}
|
||||
|
||||
func (a *Account) Write(w http.ResponseWriter) error {
|
||||
|
@ -28,7 +28,7 @@ func (r *Report) Read(json_str string) error {
|
||||
}
|
||||
|
||||
type ReportList struct {
|
||||
Reports *[]Report `json:"reports"`
|
||||
Reports *[]*Report `json:"reports"`
|
||||
}
|
||||
|
||||
func (rl *ReportList) Write(w http.ResponseWriter) error {
|
||||
|
@ -82,12 +82,12 @@ type Transaction struct {
|
||||
}
|
||||
|
||||
type TransactionList struct {
|
||||
Transactions *[]Transaction `json:"transactions"`
|
||||
Transactions *[]*Transaction `json:"transactions"`
|
||||
}
|
||||
|
||||
type AccountTransactionsList struct {
|
||||
Account *Account
|
||||
Transactions *[]Transaction
|
||||
Transactions *[]*Transaction
|
||||
TotalTransactions int64
|
||||
BeginningBalance string
|
||||
EndingBalance string
|
||||
|
133
internal/store/db/accounts.go
Normal file
133
internal/store/db/accounts.go
Normal file
@ -0,0 +1,133 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
)
|
||||
|
||||
func (tx *Tx) GetAccount(accountid int64, userid int64) (*models.Account, error) {
|
||||
var account models.Account
|
||||
|
||||
err := tx.SelectOne(&account, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetAccounts(userid int64) (*[]*models.Account, error) {
|
||||
var accounts []*models.Account
|
||||
|
||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &accounts, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) {
|
||||
var accounts []*models.Account
|
||||
|
||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC", account.UserId, account.SecurityId, account.Type, account.Name, account.ParentAccountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &accounts, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) insertUpdateAccount(account *models.Account, insert bool) error {
|
||||
found := make(map[int64]bool)
|
||||
if !insert {
|
||||
found[account.AccountId] = true
|
||||
}
|
||||
parentid := account.ParentAccountId
|
||||
depth := 0
|
||||
for parentid != -1 {
|
||||
depth += 1
|
||||
if depth > 100 {
|
||||
return store.TooMuchNestingError{}
|
||||
}
|
||||
|
||||
var a models.Account
|
||||
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
|
||||
if err != nil {
|
||||
return store.ParentAccountMissingError{}
|
||||
}
|
||||
|
||||
// Insertion by itself can never result in circular dependencies
|
||||
if insert {
|
||||
break
|
||||
}
|
||||
|
||||
found[parentid] = true
|
||||
parentid = a.ParentAccountId
|
||||
if _, ok := found[parentid]; ok {
|
||||
return store.CircularAccountsError{}
|
||||
}
|
||||
}
|
||||
|
||||
if insert {
|
||||
err := tx.Insert(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
oldacct, err := tx.GetAccount(account.AccountId, account.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.AccountVersion = oldacct.AccountVersion + 1
|
||||
|
||||
count, err := tx.Update(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one account")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) InsertAccount(account *models.Account) error {
|
||||
return tx.insertUpdateAccount(account, true)
|
||||
}
|
||||
|
||||
func (tx *Tx) UpdateAccount(account *models.Account) error {
|
||||
return tx.insertUpdateAccount(account, false)
|
||||
}
|
||||
|
||||
func (tx *Tx) DeleteAccount(account *models.Account) error {
|
||||
if account.ParentAccountId != -1 {
|
||||
// Re-parent splits to this account's parent account if this account isn't a root account
|
||||
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", account.ParentAccountId, account.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Delete splits if this account is a root account
|
||||
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", account.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Re-parent child accounts to this account's parent account
|
||||
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", account.ParentAccountId, account.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Delete(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Was going to delete more than one account")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"github.com/aclindsa/gorp"
|
||||
"github.com/aclindsa/moneygo/internal/config"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@ -18,7 +19,7 @@ import (
|
||||
// implementation's string type specified by the same.
|
||||
const luaMaxLengthBuffer int = 4096
|
||||
|
||||
func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
||||
func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
||||
var dialect gorp.Dialect
|
||||
if dbtype == config.SQLite {
|
||||
dialect = gorp.SqliteDialect{}
|
||||
@ -38,11 +39,11 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
||||
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
||||
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
|
||||
dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId")
|
||||
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
|
||||
dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId")
|
||||
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId")
|
||||
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
|
||||
dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId")
|
||||
dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId")
|
||||
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId")
|
||||
rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId")
|
||||
rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer)
|
||||
|
||||
@ -54,9 +55,50 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
||||
return dbmap, nil
|
||||
}
|
||||
|
||||
func GetDSN(dbtype config.DbType, dsn string) string {
|
||||
func getDSN(dbtype config.DbType, dsn string) string {
|
||||
if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") {
|
||||
log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!")
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
|
||||
type DbStore struct {
|
||||
dbMap *gorp.DbMap
|
||||
}
|
||||
|
||||
func (db *DbStore) Empty() error {
|
||||
return db.dbMap.TruncateTables()
|
||||
}
|
||||
|
||||
func (db *DbStore) Begin() (store.Tx, error) {
|
||||
tx, err := db.dbMap.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{db.dbMap.Dialect, tx}, nil
|
||||
}
|
||||
|
||||
func (db *DbStore) Close() error {
|
||||
err := db.dbMap.Db.Close()
|
||||
db.dbMap = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
|
||||
dsn = getDSN(dbtype, dsn)
|
||||
database, err := sql.Open(dbtype.String(), dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
database.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
dbmap, err := getDbMap(database, dbtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DbStore{dbmap}, nil
|
||||
}
|
78
internal/store/db/prices.go
Normal file
78
internal/store/db/prices.go
Normal file
@ -0,0 +1,78 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (tx *Tx) PriceExists(price *models.Price) (bool, error) {
|
||||
var prices []*models.Price
|
||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
||||
return len(prices) > 0, err
|
||||
}
|
||||
|
||||
func (tx *Tx) InsertPrice(price *models.Price) error {
|
||||
return tx.Insert(price)
|
||||
}
|
||||
|
||||
func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) {
|
||||
var price models.Price
|
||||
err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &price, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) {
|
||||
var prices []*models.Price
|
||||
|
||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &prices, nil
|
||||
}
|
||||
|
||||
// Return the latest price for security in currency units before date
|
||||
func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||
var price models.Price
|
||||
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &price, nil
|
||||
}
|
||||
|
||||
// Return the earliest price for security in currency units after date
|
||||
func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||
var price models.Price
|
||||
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &price, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) UpdatePrice(price *models.Price) error {
|
||||
count, err := tx.Update(price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to update 1 price, was going to update %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) DeletePrice(price *models.Price) error {
|
||||
count, err := tx.Delete(price)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
56
internal/store/db/reports.go
Normal file
56
internal/store/db/reports.go
Normal file
@ -0,0 +1,56 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
)
|
||||
|
||||
func (tx *Tx) GetReport(reportid int64, userid int64) (*models.Report, error) {
|
||||
var r models.Report
|
||||
|
||||
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetReports(userid int64) (*[]*models.Report, error) {
|
||||
var reports []*models.Report
|
||||
|
||||
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &reports, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) InsertReport(report *models.Report) error {
|
||||
err := tx.Insert(report)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) UpdateReport(report *models.Report) error {
|
||||
count, err := tx.Update(report)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to update 1 report, was going to update %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) DeleteReport(report *models.Report) error {
|
||||
count, err := tx.Delete(report)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to delete 1 report, was going to delete %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
88
internal/store/db/securities.go
Normal file
88
internal/store/db/securities.go
Normal file
@ -0,0 +1,88 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
)
|
||||
|
||||
func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) {
|
||||
var s models.Security
|
||||
|
||||
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) {
|
||||
var securities []*models.Security
|
||||
|
||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &securities, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) {
|
||||
var securities []*models.Security
|
||||
|
||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", security.UserId, security.Type, security.AlternateId, security.Precision)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &securities, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) InsertSecurity(s *models.Security) error {
|
||||
err := tx.Insert(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) UpdateSecurity(security *models.Security) error {
|
||||
count, err := tx.Update(security)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to update 1 security, was going to update %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) DeleteSecurity(s *models.Security) error {
|
||||
// First, ensure no accounts are using this security
|
||||
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||
|
||||
if accounts != 0 {
|
||||
return store.SecurityInUseError{"One or more accounts still use this security"}
|
||||
}
|
||||
|
||||
user, err := tx.GetUser(s.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if user.DefaultCurrency == s.SecurityId {
|
||||
return store.SecurityInUseError{"Cannot delete security which is user's default currency"}
|
||||
}
|
||||
|
||||
// Remove all prices involving this security (either of this security, or
|
||||
// using it as a currency)
|
||||
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Delete(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
42
internal/store/db/sessions.go
Normal file
42
internal/store/db/sessions.go
Normal file
@ -0,0 +1,42 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (tx *Tx) InsertSession(session *models.Session) error {
|
||||
return tx.Insert(session)
|
||||
}
|
||||
|
||||
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
|
||||
var s models.Session
|
||||
|
||||
err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.Expires.Before(time.Now()) {
|
||||
tx.Delete(&s)
|
||||
return nil, fmt.Errorf("Session has expired")
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) SessionExists(secret string) (bool, error) {
|
||||
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
|
||||
return existing != 0, err
|
||||
}
|
||||
|
||||
func (tx *Tx) DeleteSession(session *models.Session) error {
|
||||
count, err := tx.Delete(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
361
internal/store/db/transactions.go
Normal file
361
internal/store/db/transactions.go
Normal file
@ -0,0 +1,361 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"github.com/aclindsa/moneygo/internal/store"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error {
|
||||
for i := range accountids {
|
||||
account, err := tx.GetAccount(accountids[i], user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
account.AccountVersion++
|
||||
count, err := tx.Update(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Updated more than one account")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error {
|
||||
// Map of any accounts with transaction splits being added
|
||||
a_map := make(map[int64]bool)
|
||||
for i := range t.Splits {
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != 1 {
|
||||
return store.AccountMissingError{}
|
||||
}
|
||||
a_map[t.Splits[i].AccountId] = true
|
||||
} else if t.Splits[i].SecurityId == -1 {
|
||||
return store.AccountMissingError{}
|
||||
}
|
||||
}
|
||||
|
||||
//increment versions for all accounts
|
||||
var a_ids []int64
|
||||
for id := range a_map {
|
||||
a_ids = append(a_ids, id)
|
||||
}
|
||||
// ensure at least one of the splits is associated with an actual account
|
||||
if len(a_ids) < 1 {
|
||||
return store.AccountMissingError{}
|
||||
}
|
||||
err := tx.incrementAccountVersions(user, a_ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.UserId = user.UserId
|
||||
err = tx.Insert(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range t.Splits {
|
||||
t.Splits[i].TransactionId = t.TransactionId
|
||||
t.Splits[i].SplitId = -1
|
||||
err = tx.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) SplitExists(s *models.Split) (bool, error) {
|
||||
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
||||
return count == 1, err
|
||||
}
|
||||
|
||||
func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) {
|
||||
var t models.Transaction
|
||||
|
||||
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) {
|
||||
var transactions []*models.Transaction
|
||||
|
||||
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range transactions {
|
||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &transactions, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error {
|
||||
var existing_splits []*models.Split
|
||||
|
||||
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Map of any accounts with transaction splits being added
|
||||
a_map := make(map[int64]bool)
|
||||
|
||||
// Make a map with any existing splits for this transaction
|
||||
s_map := make(map[int64]bool)
|
||||
for i := range existing_splits {
|
||||
s_map[existing_splits[i].SplitId] = true
|
||||
}
|
||||
|
||||
// Insert splits, updating any pre-existing ones
|
||||
for i := range t.Splits {
|
||||
t.Splits[i].TransactionId = t.TransactionId
|
||||
_, ok := s_map[t.Splits[i].SplitId]
|
||||
if ok {
|
||||
count, err := tx.Update(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 1 {
|
||||
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
|
||||
}
|
||||
delete(s_map, t.Splits[i].SplitId)
|
||||
} else {
|
||||
t.Splits[i].SplitId = -1
|
||||
err := tx.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
a_map[t.Splits[i].AccountId] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Delete any remaining pre-existing splits
|
||||
for i := range existing_splits {
|
||||
_, ok := s_map[existing_splits[i].SplitId]
|
||||
if existing_splits[i].AccountId != -1 {
|
||||
a_map[existing_splits[i].AccountId] = true
|
||||
}
|
||||
if ok {
|
||||
_, err := tx.Delete(existing_splits[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment versions for all accounts with modified splits
|
||||
var a_ids []int64
|
||||
for id := range a_map {
|
||||
a_ids = append(a_ids, id)
|
||||
}
|
||||
err = tx.incrementAccountVersions(user, a_ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Update(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 1 {
|
||||
return fmt.Errorf("Updated %d transactions (expected 1)", count)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error {
|
||||
var accountids []int64
|
||||
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := tx.Delete(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return errors.New("Deleted more than one transaction")
|
||||
}
|
||||
|
||||
err = tx.incrementAccountVersions(user, accountids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) {
|
||||
var splits []*models.Split
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||
_, err := tx.Select(&splits, sql, accountid, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &splits, nil
|
||||
}
|
||||
|
||||
// Assumes accountid is valid and is owned by the current user
|
||||
func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) {
|
||||
var splits []*models.Split
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &splits, err
|
||||
}
|
||||
|
||||
func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) {
|
||||
var splits []*models.Split
|
||||
|
||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &splits, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) {
|
||||
var pageDifference, tmp big.Rat
|
||||
for i := range transactions {
|
||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sum up the amounts from the splits we're returning so we can return
|
||||
// an ending balance
|
||||
for j := range transactions[i].Splits {
|
||||
if transactions[i].Splits[j].AccountId == accountid {
|
||||
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&pageDifference, rat_amount)
|
||||
pageDifference.Set(&tmp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &pageDifference, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
|
||||
var transactions []*models.Transaction
|
||||
var atl models.AccountTransactionsList
|
||||
|
||||
var sqlsort, balanceLimitOffset string
|
||||
var balanceLimitOffsetArg uint64
|
||||
if sort == "date-asc" {
|
||||
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
|
||||
balanceLimitOffset = " LIMIT ?"
|
||||
balanceLimitOffsetArg = page * limit
|
||||
} else if sort == "date-desc" {
|
||||
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
|
||||
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
|
||||
balanceLimitOffsetArg = (page + 1) * limit
|
||||
}
|
||||
|
||||
var sqloffset string
|
||||
if page > 0 {
|
||||
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
||||
}
|
||||
|
||||
account, err := tx.GetAccount(accountid, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atl.Account = account
|
||||
|
||||
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
||||
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atl.Transactions = &transactions
|
||||
|
||||
pageDifference, err := tx.transactionsBalanceDifference(accountid, transactions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atl.TotalTransactions = count
|
||||
|
||||
security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if security == nil {
|
||||
return nil, errors.New("Security not found")
|
||||
}
|
||||
|
||||
// Sum all the splits for all transaction splits for this account that
|
||||
// occurred before the page we're returning
|
||||
var amounts []string
|
||||
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
|
||||
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tmp, balance big.Rat
|
||||
for _, amount := range amounts {
|
||||
rat_amount, err := models.GetBigAmount(amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp.Add(&balance, rat_amount)
|
||||
balance.Set(&tmp)
|
||||
}
|
||||
atl.BeginningBalance = balance.FloatString(security.Precision)
|
||||
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
||||
|
||||
return &atl, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package handlers
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@ -41,7 +41,20 @@ func (tx *Tx) Insert(list ...interface{}) error {
|
||||
}
|
||||
|
||||
func (tx *Tx) Update(list ...interface{}) (int64, error) {
|
||||
return tx.Tx.Update(list...)
|
||||
count, err := tx.Tx.Update(list...)
|
||||
if count == 0 {
|
||||
switch tx.Dialect.(type) {
|
||||
case gorp.MySQLDialect:
|
||||
// Always return 1 for 0 if we're using MySQL because it returns
|
||||
// count=0 if the row data was unchanged, even if the row existed
|
||||
|
||||
// TODO Find another way to fix this without risking ignoring
|
||||
// errors
|
||||
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
|
||||
@ -55,11 +68,3 @@ func (tx *Tx) Commit() error {
|
||||
func (tx *Tx) Rollback() error {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
|
||||
func GetTx(db *gorp.DbMap) (*Tx, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{db.Dialect, tx}, nil
|
||||
}
|
86
internal/store/db/users.go
Normal file
86
internal/store/db/users.go
Normal file
@ -0,0 +1,86 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
)
|
||||
|
||||
func (tx *Tx) UsernameExists(username string) (bool, error) {
|
||||
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username)
|
||||
return existing != 0, err
|
||||
}
|
||||
|
||||
func (tx *Tx) InsertUser(user *models.User) error {
|
||||
return tx.Insert(user)
|
||||
}
|
||||
|
||||
func (tx *Tx) GetUser(userid int64) (*models.User, error) {
|
||||
var u models.User
|
||||
|
||||
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) GetUserByUsername(username string) (*models.User, error) {
|
||||
var u models.User
|
||||
|
||||
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) UpdateUser(user *models.User) error {
|
||||
count, err := tx.Update(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to update 1 user, was going to update %d", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) DeleteUser(user *models.User) error {
|
||||
count, err := tx.Delete(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
123
internal/store/store.go
Normal file
123
internal/store/store.go
Normal file
@ -0,0 +1,123 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"github.com/aclindsa/moneygo/internal/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserStore interface {
|
||||
UsernameExists(username string) (bool, error)
|
||||
InsertUser(user *models.User) error
|
||||
GetUser(userid int64) (*models.User, error)
|
||||
GetUserByUsername(username string) (*models.User, error)
|
||||
UpdateUser(user *models.User) error
|
||||
DeleteUser(user *models.User) error
|
||||
}
|
||||
|
||||
type SessionStore interface {
|
||||
SessionExists(secret string) (bool, error)
|
||||
InsertSession(session *models.Session) error
|
||||
GetSession(secret string) (*models.Session, error)
|
||||
DeleteSession(session *models.Session) error
|
||||
}
|
||||
|
||||
type SecurityInUseError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e SecurityInUseError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
type SecurityStore interface {
|
||||
InsertSecurity(security *models.Security) error
|
||||
GetSecurity(securityid int64, userid int64) (*models.Security, error)
|
||||
GetSecurities(userid int64) (*[]*models.Security, error)
|
||||
FindMatchingSecurities(security *models.Security) (*[]*models.Security, error)
|
||||
UpdateSecurity(security *models.Security) error
|
||||
DeleteSecurity(security *models.Security) error
|
||||
}
|
||||
|
||||
type PriceStore interface {
|
||||
PriceExists(price *models.Price) (bool, error)
|
||||
InsertPrice(price *models.Price) error
|
||||
GetPrice(priceid, securityid int64) (*models.Price, error)
|
||||
GetPrices(securityid int64) (*[]*models.Price, error)
|
||||
GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error)
|
||||
GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error)
|
||||
UpdatePrice(price *models.Price) error
|
||||
DeletePrice(price *models.Price) error
|
||||
}
|
||||
|
||||
type ParentAccountMissingError struct{}
|
||||
|
||||
func (pame ParentAccountMissingError) Error() string {
|
||||
return "Parent account missing"
|
||||
}
|
||||
|
||||
type TooMuchNestingError struct{}
|
||||
|
||||
func (tmne TooMuchNestingError) Error() string {
|
||||
return "Too much account nesting"
|
||||
}
|
||||
|
||||
type CircularAccountsError struct{}
|
||||
|
||||
func (cae CircularAccountsError) Error() string {
|
||||
return "Would result in circular account relationship"
|
||||
}
|
||||
|
||||
type AccountStore interface {
|
||||
InsertAccount(account *models.Account) error
|
||||
GetAccount(accountid int64, userid int64) (*models.Account, error)
|
||||
GetAccounts(userid int64) (*[]*models.Account, error)
|
||||
FindMatchingAccounts(account *models.Account) (*[]*models.Account, error)
|
||||
UpdateAccount(account *models.Account) error
|
||||
DeleteAccount(account *models.Account) error
|
||||
}
|
||||
|
||||
type AccountMissingError struct{}
|
||||
|
||||
func (ame AccountMissingError) Error() string {
|
||||
return "Account missing"
|
||||
}
|
||||
|
||||
type TransactionStore interface {
|
||||
SplitExists(s *models.Split) (bool, error)
|
||||
InsertTransaction(t *models.Transaction, user *models.User) error
|
||||
GetTransaction(transactionid int64, userid int64) (*models.Transaction, error)
|
||||
GetTransactions(userid int64) (*[]*models.Transaction, error)
|
||||
UpdateTransaction(t *models.Transaction, user *models.User) error
|
||||
DeleteTransaction(t *models.Transaction, user *models.User) error
|
||||
GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error)
|
||||
GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error)
|
||||
GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error)
|
||||
GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error)
|
||||
}
|
||||
|
||||
type ReportStore interface {
|
||||
InsertReport(report *models.Report) error
|
||||
GetReport(reportid int64, userid int64) (*models.Report, error)
|
||||
GetReports(userid int64) (*[]*models.Report, error)
|
||||
UpdateReport(report *models.Report) error
|
||||
DeleteReport(report *models.Report) error
|
||||
}
|
||||
|
||||
type Tx interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
|
||||
UserStore
|
||||
SessionStore
|
||||
SecurityStore
|
||||
PriceStore
|
||||
AccountStore
|
||||
TransactionStore
|
||||
ReportStore
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
Empty() error
|
||||
Begin() (Tx, error)
|
||||
Close() error
|
||||
}
|
15
main.go
15
main.go
@ -3,11 +3,10 @@ package main
|
||||
//go:generate make
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"flag"
|
||||
"github.com/aclindsa/moneygo/internal/config"
|
||||
"github.com/aclindsa/moneygo/internal/db"
|
||||
"github.com/aclindsa/moneygo/internal/handlers"
|
||||
"github.com/aclindsa/moneygo/internal/store/db"
|
||||
"github.com/kabukky/httpscerts"
|
||||
"log"
|
||||
"net"
|
||||
@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
|
||||
database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType)
|
||||
db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Get ServeMux for API and add our own handlers for files
|
||||
servemux := http.NewServeMux()
|
||||
servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap})
|
||||
servemux.Handle("/v1/", &handlers.APIHandler{Store: db})
|
||||
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
|
||||
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user