diff --git a/.travis.yml b/.travis.yml index e4a54dc..00f2cc2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -59,9 +59,11 @@ script: - touch $GOPATH/src/github.com/aclindsa/moneygo/internal/handlers/cusip_list.csv # Build and test MoneyGo - go generate -v github.com/aclindsa/moneygo/internal/handlers - - go test -v -covermode=count -coverpkg github.com/aclindsa/moneygo/internal/config,github.com/aclindsa/moneygo/internal/handlers,github.com/aclindsa/moneygo/internal/models,github.com/aclindsa/moneygo/internal/store,github.com/aclindsa/moneygo/internal/store/db,github.com/aclindsa/moneygo/internal/reports -coverprofile=integration_coverage.out github.com/aclindsa/moneygo/internal/integration - - go test -v -covermode=count -coverpkg github.com/aclindsa/moneygo/internal/config,github.com/aclindsa/moneygo/internal/handlers,github.com/aclindsa/moneygo/internal/models,github.com/aclindsa/moneygo/internal/store,github.com/aclindsa/moneygo/internal/store/db,github.com/aclindsa/moneygo/internal/reports -coverprofile=config_coverage.out github.com/aclindsa/moneygo/internal/config + - export COVER_PACKAGES="github.com/aclindsa/moneygo/internal/config,github.com/aclindsa/moneygo/internal/handlers,github.com/aclindsa/moneygo/internal/models,github.com/aclindsa/moneygo/internal/store,github.com/aclindsa/moneygo/internal/store/db,github.com/aclindsa/moneygo/internal/reports" + - go test -v -covermode=count -coverpkg $COVER_PACKAGES -coverprofile=integration_coverage.out github.com/aclindsa/moneygo/internal/integration + - go test -v -covermode=count -coverpkg $COVER_PACKAGES -coverprofile=config_coverage.out github.com/aclindsa/moneygo/internal/config + - go test -v -covermode=count -coverpkg $COVER_PACKAGES -coverprofile=models_coverage.out github.com/aclindsa/moneygo/internal/models # Report the test coverage after_script: - - $GOPATH/bin/goveralls -coverprofile=integration_coverage.out,config_coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN + - $GOPATH/bin/goveralls -coverprofile=integration_coverage.out,config_coverage.out,models_coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 05fa0af..7a80c69 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -10,7 +10,6 @@ import ( "io" "log" "math" - "math/big" "net/http" "time" ) @@ -49,7 +48,7 @@ func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) gc.Precision = template.Precision } else { if gxc.Fraction > 0 { - gc.Precision = int(math.Ceil(math.Log10(float64(gxc.Fraction)))) + gc.Precision = uint64(math.Ceil(math.Log10(float64(gxc.Fraction)))) } else { gc.Precision = 0 } @@ -178,13 +177,14 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { p.CurrencyId = currency.SecurityId p.Date = price.Date.Date.Time - var r big.Rat - _, ok = r.SetString(price.Value) - if ok { - p.Value = r.FloatString(currency.Precision) - } else { + _, ok = p.Value.SetString(price.Value) + if !ok { return nil, fmt.Errorf("Can't set price value: %s", price.Value) } + if p.Value.Precision() > currency.Precision { + // TODO we're possibly losing data here... but do we care? + p.Value.Round(currency.Precision) + } p.RemoteId = "gnucash:" + price.Id gncimport.Prices = append(gncimport.Prices, p) @@ -293,13 +293,13 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { s.Number = gt.Number s.Memo = gs.Memo - var r big.Rat - _, ok = r.SetString(gs.Amount) - if ok { - s.Amount = r.FloatString(security.Precision) - } else { + _, ok = s.Amount.SetString(gs.Amount) + if !ok { return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount) } + if s.Amount.Precision() > security.Precision { + return nil, fmt.Errorf("Imported price's precision (%d) is greater than the security's (%s)\n", s.Amount.Precision(), security) + } t.Splits = append(t.Splits, s) } @@ -356,6 +356,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite } if err != nil { + log.Print(err) return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 58696d0..b8a9924 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -164,7 +164,11 @@ func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int6 log.Print(err) return NewError(999 /*Internal Error*/) } - split.Amount = r.FloatString(security.Precision) + split.Amount.Rat = *r + if split.Amount.Precision() > security.Precision { + log.Printf("Precision on created imbalance-correction split (%d) greater than the underlying security (%s) allows (%d)", split.Amount.Precision(), security, security.Precision) + return NewError(999 /*Internal Error*/) + } split.SecurityId = -1 split.AccountId = imbalanced_account.AccountId transaction.Splits = append(transaction.Splits, split) diff --git a/internal/handlers/ofx.go b/internal/handlers/ofx.go index a183aab..77cf066 100644 --- a/internal/handlers/ofx.go +++ b/internal/handlers/ofx.go @@ -97,9 +97,12 @@ func (i *OFXImport) AddTransaction(tran *ofxgo.Transaction, account *models.Acco s1.ImportSplitType = models.ImportAccount s2.ImportSplitType = models.ExternalAccount + s1.Amount.Rat = *amt + s2.Amount.Rat = *amt.Neg(amt) security := i.Securities[account.SecurityId-1] - s1.Amount = amt.FloatString(security.Precision) - s2.Amount = amt.Neg(amt).FloatString(security.Precision) + if s1.Amount.Precision() > security.Precision { + return errors.New("Imported transaction amount is too precise for security") + } s1.Status = models.Imported s2.Status = models.Imported @@ -262,7 +265,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo + "(commission)", - Amount: commission.FloatString(curdef.Precision), + Amount: models.Amount{commission}, }) } if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -274,7 +277,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo + "(taxes)", - Amount: taxes.FloatString(curdef.Precision), + Amount: models.Amount{taxes}, }) } if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -286,7 +289,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo + "(fees)", - Amount: fees.FloatString(curdef.Precision), + Amount: models.Amount{fees}, }) } if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -298,7 +301,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo + "(load)", - Amount: load.FloatString(curdef.Precision), + Amount: models.Amount{load}, }) } t.Splits = append(t.Splits, &models.Split{ @@ -309,7 +312,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: -1, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? @@ -319,7 +322,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo, - Amount: tradingTotal.FloatString(curdef.Precision), + Amount: models.Amount{tradingTotal}, }) var units big.Rat @@ -332,7 +335,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: security.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) units.Neg(&units) t.Splits = append(t.Splits, &models.Split{ @@ -343,7 +346,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac SecurityId: security.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) return &t, nil @@ -378,7 +381,7 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, SecurityId: -1, RemoteId: "ofx:" + income.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) total.Neg(&total) t.Splits = append(t.Splits, &models.Split{ @@ -389,7 +392,7 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + income.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) return &t, nil @@ -423,7 +426,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models. SecurityId: -1, RemoteId: "ofx:" + expense.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) total.Neg(&total) t.Splits = append(t.Splits, &models.Split{ @@ -434,7 +437,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models. SecurityId: curdef.SecurityId, RemoteId: "ofx:" + expense.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) return &t, nil @@ -462,7 +465,7 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde SecurityId: -1, RemoteId: "ofx:" + marginint.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) total.Neg(&total) t.Splits = append(t.Splits, &models.Split{ @@ -473,7 +476,7 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde SecurityId: curdef.SecurityId, RemoteId: "ofx:" + marginint.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) return &t, nil @@ -526,7 +529,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo + "(commission)", - Amount: commission.FloatString(curdef.Precision), + Amount: models.Amount{commission}, }) } if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -538,7 +541,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo + "(taxes)", - Amount: taxes.FloatString(curdef.Precision), + Amount: models.Amount{taxes}, }) } if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -550,7 +553,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo + "(fees)", - Amount: fees.FloatString(curdef.Precision), + Amount: models.Amount{fees}, }) } if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -562,7 +565,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo + "(load)", - Amount: load.FloatString(curdef.Precision), + Amount: models.Amount{load}, }) } t.Splits = append(t.Splits, &models.Split{ @@ -573,7 +576,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: -1, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) t.Splits = append(t.Splits, &models.Split{ @@ -584,7 +587,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) total.Neg(&total) t.Splits = append(t.Splits, &models.Split{ @@ -595,7 +598,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: -1, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? @@ -605,7 +608,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, - Amount: tradingTotal.FloatString(curdef.Precision), + Amount: models.Amount{tradingTotal}, }) var units big.Rat @@ -618,7 +621,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: security.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) units.Neg(&units) t.Splits = append(t.Splits, &models.Split{ @@ -629,7 +632,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec SecurityId: security.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) return &t, nil @@ -663,7 +666,7 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec SecurityId: -1, RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) total.Neg(&total) t.Splits = append(t.Splits, &models.Split{ @@ -674,7 +677,7 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec SecurityId: curdef.SecurityId, RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) return &t, nil @@ -730,7 +733,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo + "(commission)", - Amount: commission.FloatString(curdef.Precision), + Amount: models.Amount{commission}, }) } if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -742,7 +745,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo + "(taxes)", - Amount: taxes.FloatString(curdef.Precision), + Amount: models.Amount{taxes}, }) } if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -754,7 +757,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo + "(fees)", - Amount: fees.FloatString(curdef.Precision), + Amount: models.Amount{fees}, }) } if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { @@ -766,7 +769,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo + "(load)", - Amount: load.FloatString(curdef.Precision), + Amount: models.Amount{load}, }) } t.Splits = append(t.Splits, &models.Split{ @@ -777,7 +780,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: -1, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo, - Amount: total.FloatString(curdef.Precision), + Amount: models.Amount{total}, }) t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? @@ -787,7 +790,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo, - Amount: tradingTotal.FloatString(curdef.Precision), + Amount: models.Amount{tradingTotal}, }) var units big.Rat @@ -800,7 +803,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: security.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) units.Neg(&units) t.Splits = append(t.Splits, &models.Split{ @@ -811,7 +814,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, SecurityId: security.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) return &t, nil @@ -842,7 +845,7 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *models.Ac SecurityId: security.SecurityId, RemoteId: "ofx:" + transfer.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) units.Neg(&units) t.Splits = append(t.Splits, &models.Split{ @@ -853,7 +856,7 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *models.Ac SecurityId: security.SecurityId, RemoteId: "ofx:" + transfer.InvTran.FiTID.String(), Memo: memo, - Amount: units.FloatString(security.Precision), + Amount: models.Amount{units}, }) return &t, nil diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 8dd6557..68fd311 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -31,9 +31,8 @@ func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big } securityid = account.SecurityId } - amount, _ := t.Splits[i].GetAmount() sum := sums[securityid] - (&sum).Add(&sum, amount) + (&sum).Add(&sum, &t.Splits[i].Amount.Rat) sums[securityid] = sum } return sums, nil diff --git a/internal/integration/common_test.go b/internal/integration/common_test.go index 4d33d85..0dfd95e 100644 --- a/internal/integration/common_test.go +++ b/internal/integration/common_test.go @@ -202,6 +202,19 @@ func uploadFile(client *http.Client, filename, urlsuffix string) error { return nil } +func NewAmount(amt string) models.Amount { + var a models.Amount + if _, ok := a.SetString(amt); !ok { + panic("Unable to call Amount.SetString()") + } + return a +} + +func amountsMatch(a models.Amount, amt string) bool { + cmp := NewAmount(amt) + return a.Cmp(&cmp.Rat) == 0 +} + func accountBalanceHelper(t *testing.T, client *http.Client, account *models.Account, balance string) { t.Helper() transactions, err := getAccountTransactions(client, account.AccountId, 0, 0, "") @@ -209,7 +222,7 @@ func accountBalanceHelper(t *testing.T, client *http.Client, account *models.Acc t.Fatalf("Couldn't fetch account transactions for '%s': %s\n", account.Name, err) } - if transactions.EndingBalance != balance { + if !amountsMatch(transactions.EndingBalance, balance) { t.Errorf("Expected ending balance for '%s' to be '%s', but found %s\n", account.Name, balance, transactions.EndingBalance) } } diff --git a/internal/integration/gnucash_test.go b/internal/integration/gnucash_test.go index cd471c4..6e15e62 100644 --- a/internal/integration/gnucash_test.go +++ b/internal/integration/gnucash_test.go @@ -114,11 +114,11 @@ func TestImportGnucash(t *testing.T) { } var p1787, p2894, p3170 bool for _, price := range *prices.Prices { - if price.CurrencyId == d.securities[0].SecurityId && price.Value == "17.87" { + if price.CurrencyId == d.securities[0].SecurityId && amountsMatch(price.Value, "17.87") { p1787 = true - } else if price.CurrencyId == d.securities[0].SecurityId && price.Value == "28.94" { + } else if price.CurrencyId == d.securities[0].SecurityId && amountsMatch(price.Value, "28.94") { p2894 = true - } else if price.CurrencyId == d.securities[0].SecurityId && price.Value == "31.70" { + } else if price.CurrencyId == d.securities[0].SecurityId && amountsMatch(price.Value, "31.70") { p3170 = true } } diff --git a/internal/integration/prices_test.go b/internal/integration/prices_test.go index cb1fc5f..74f2143 100644 --- a/internal/integration/prices_test.go +++ b/internal/integration/prices_test.go @@ -68,7 +68,7 @@ func TestCreatePrice(t *testing.T) { if !p.Date.Equal(orig.Date) { t.Errorf("Date doesn't match") } - if p.Value != orig.Value { + if p.Value.Cmp(&orig.Value.Rat) != 0 { t.Errorf("Value doesn't match") } if p.RemoteId != orig.RemoteId { @@ -98,7 +98,7 @@ func TestGetPrice(t *testing.T) { if !p.Date.Equal(orig.Date) { t.Errorf("Date doesn't match") } - if p.Value != orig.Value { + if p.Value.Cmp(&orig.Value.Rat) != 0 { t.Errorf("Value doesn't match") } if p.RemoteId != orig.RemoteId { @@ -132,7 +132,7 @@ func TestGetPrices(t *testing.T) { found := false for _, p := range *pl.Prices { - if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value == orig.Value && p.RemoteId == orig.RemoteId { + if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value.Cmp(&orig.Value.Rat) == 0 && p.RemoteId == orig.RemoteId { if _, ok := foundIds[p.PriceId]; ok { continue } @@ -146,7 +146,11 @@ func TestGetPrices(t *testing.T) { } } - if numprices != len(*pl.Prices) { + if pl.Prices == nil { + if numprices != 0 { + t.Fatalf("Expected %d prices, received 0", numprices) + } + } else if numprices != len(*pl.Prices) { t.Fatalf("Expected %d prices, received %d", numprices, len(*pl.Prices)) } } @@ -162,7 +166,7 @@ func TestUpdatePrice(t *testing.T) { tmp := curr.SecurityId curr.SecurityId = curr.CurrencyId curr.CurrencyId = tmp - curr.Value = "5.55" + curr.Value = NewAmount("5.55") curr.Date = time.Date(2019, time.June, 5, 12, 5, 6, 7, time.UTC) curr.RemoteId = "something" @@ -181,7 +185,7 @@ func TestUpdatePrice(t *testing.T) { if !p.Date.Equal(curr.Date) { t.Errorf("Date doesn't match") } - if p.Value != curr.Value { + if p.Value.Cmp(&curr.Value.Rat) != 0 { t.Errorf("Value doesn't match") } if p.RemoteId != curr.RemoteId { diff --git a/internal/integration/testdata_test.go b/internal/integration/testdata_test.go index 465fa19..59b1f91 100644 --- a/internal/integration/testdata_test.go +++ b/internal/integration/testdata_test.go @@ -213,35 +213,35 @@ var data = []TestData{ SecurityId: 1, CurrencyId: 0, Date: time.Date(2017, time.January, 2, 21, 0, 0, 0, time.UTC), - Value: "225.24", + Value: NewAmount("225.24"), RemoteId: "12387-129831-1238", }, { SecurityId: 1, CurrencyId: 0, Date: time.Date(2017, time.January, 3, 21, 0, 0, 0, time.UTC), - Value: "226.58", + Value: NewAmount("226.58"), RemoteId: "12387-129831-1239", }, { SecurityId: 1, CurrencyId: 0, Date: time.Date(2017, time.January, 4, 21, 0, 0, 0, time.UTC), - Value: "226.40", + Value: NewAmount("226.40"), RemoteId: "12387-129831-1240", }, { SecurityId: 1, CurrencyId: 0, Date: time.Date(2017, time.January, 5, 21, 0, 0, 0, time.UTC), - Value: "227.21", + Value: NewAmount("227.21"), RemoteId: "12387-129831-1241", }, { SecurityId: 0, CurrencyId: 3, Date: time.Date(2017, time.November, 16, 18, 49, 53, 0, time.UTC), - Value: "0.85", + Value: NewAmount("0.85"), RemoteId: "USDEUR819298714", }, }, @@ -313,13 +313,13 @@ var data = []TestData{ Status: models.Reconciled, AccountId: 1, SecurityId: -1, - Amount: "-5.6", + Amount: NewAmount("-5.6"), }, { Status: models.Reconciled, AccountId: 3, SecurityId: -1, - Amount: "5.6", + Amount: NewAmount("5.6"), }, }, }, @@ -332,13 +332,13 @@ var data = []TestData{ Status: models.Reconciled, AccountId: 1, SecurityId: -1, - Amount: "-81.59", + Amount: NewAmount("-81.59"), }, { Status: models.Reconciled, AccountId: 3, SecurityId: -1, - Amount: "81.59", + Amount: NewAmount("81.59"), }, }, }, @@ -351,13 +351,13 @@ var data = []TestData{ Status: models.Reconciled, AccountId: 1, SecurityId: -1, - Amount: "-39.99", + Amount: NewAmount("-39.99"), }, { Status: models.Entered, AccountId: 4, SecurityId: -1, - Amount: "39.99", + Amount: NewAmount("39.99"), }, }, }, @@ -370,13 +370,13 @@ var data = []TestData{ Status: models.Reconciled, AccountId: 5, SecurityId: -1, - Amount: "-24.56", + Amount: NewAmount("-24.56"), }, { Status: models.Entered, AccountId: 6, SecurityId: -1, - Amount: "24.56", + Amount: NewAmount("24.56"), }, }, }, diff --git a/internal/integration/transactions_test.go b/internal/integration/transactions_test.go index c9cf73f..7a2c585 100644 --- a/internal/integration/transactions_test.go +++ b/internal/integration/transactions_test.go @@ -120,7 +120,7 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *models.Transaction, a origsplit.RemoteId == origsplit.RemoteId && origsplit.Number == s.Number && origsplit.Memo == s.Memo && - origsplit.Amount == s.Amount && + origsplit.Amount.Cmp(&s.Amount.Rat) == 0 && (!matchsplitids || origsplit.SplitId == s.SplitId) { if _, ok := foundIds[s.SplitId]; ok { @@ -187,13 +187,13 @@ func TestCreateTransaction(t *testing.T) { Status: models.Reconciled, AccountId: d.accounts[1].AccountId, SecurityId: -1, - Amount: "-39.98", + Amount: NewAmount("-39.98"), }, { Status: models.Entered, AccountId: d.accounts[4].AccountId, SecurityId: -1, - Amount: "39.99", + Amount: NewAmount("39.99"), }, }, } @@ -333,7 +333,7 @@ func TestUpdateTransaction(t *testing.T) { tran.UserId = curr.UserId // Make sure we can't create an unbalanced transaction - tran.Splits[len(tran.Splits)-1].Amount = "42" + tran.Splits[len(tran.Splits)-1].Amount = NewAmount("42") _, err = updateTransaction(d.clients[orig.UserId], tran) if err == nil { t.Fatalf("Expected error updating imbalanced transaction") diff --git a/internal/models/amounts.go b/internal/models/amounts.go new file mode 100644 index 0000000..f5627a6 --- /dev/null +++ b/internal/models/amounts.go @@ -0,0 +1,156 @@ +package models + +import ( + "encoding/json" + "fmt" + "math" + "math/big" + "strings" +) + +type Amount struct { + big.Rat +} + +type PrecisionError struct { + message string +} + +func (p PrecisionError) Error() string { + return p.message +} + +// Whole returns the integral portion of the Amount +func (amount Amount) Whole() (int64, error) { + var whole big.Int + whole.Quo(amount.Num(), amount.Denom()) + if whole.IsInt64() { + return whole.Int64(), nil + } + return 0, PrecisionError{"integral portion of Amount cannot be represented as an int64"} +} + +// Fractional returns the fractional portion of the Amount, multiplied by +// 10^precision +func (amount Amount) Fractional(precision uint64) (int64, error) { + if precision < amount.Precision() { + return 0, PrecisionError{"Fractional portion of Amount cannot be represented with the given precision"} + } + + // Reduce the fraction to its simplest form + var r, gcd, d, n big.Int + r.Rem(amount.Num(), amount.Denom()) + gcd.GCD(nil, nil, &r, amount.Denom()) + if gcd.Sign() != 0 { + n.Quo(&r, &gcd) + d.Quo(amount.Denom(), &gcd) + } else { + n.Set(&r) + d.Set(amount.Denom()) + } + + // Figure out what we need to multiply the numerator by to get the + // denominator to be 10^precision + var prec, multiplier big.Int + prec.SetUint64(precision) + multiplier.SetInt64(10) + multiplier.Exp(&multiplier, &prec, nil) + multiplier.Quo(&multiplier, &d) + + n.Mul(&n, &multiplier) + if n.IsInt64() { + return n.Int64(), nil + } + return 0, fmt.Errorf("Fractional portion of Amount does not fit in int64 with given precision") +} + +// FromParts re-assembles an Amount from the results from previous calls to +// Whole and Fractional +func (amount *Amount) FromParts(whole, fractional int64, precision uint64) { + var fracnum, fracdenom, power big.Int + fracnum.SetInt64(fractional) + fracdenom.SetInt64(10) + power.SetUint64(precision) + fracdenom.Exp(&fracdenom, &power, nil) + + var fracrat big.Rat + fracrat.SetFrac(&fracnum, &fracdenom) + amount.Rat.SetInt64(whole) + amount.Rat.Add(&amount.Rat, &fracrat) +} + +// Round rounds the given Amount to the given precision +func (amount *Amount) Round(precision uint64) { + // This probably isn't exactly the most efficient way to do this... + amount.SetString(amount.FloatString(int(precision))) +} + +func (amount Amount) String() string { + return amount.FloatString(int(amount.Precision())) +} + +func (amount *Amount) UnmarshalJSON(bytes []byte) error { + var value string + if err := json.Unmarshal(bytes, &value); err != nil { + return err + } + value = strings.TrimSpace(value) + + if _, ok := amount.SetString(value); !ok { + return fmt.Errorf("Failed to parse '%s' into Amount", value) + } + return nil +} + +func (amount Amount) MarshalJSON() ([]byte, error) { + return json.Marshal(amount.String()) +} + +// Precision returns the minimum positive integer p such that if you multiplied +// this Amount by 10^p, it would become an integer +func (amount Amount) Precision() uint64 { + if amount.IsInt() || amount.Sign() == 0 { + return 0 + } + + // Find d, the denominator of the reduced fractional portion of 'amount' + var r, gcd, d big.Int + r.Rem(amount.Num(), amount.Denom()) + gcd.GCD(nil, nil, &r, amount.Denom()) + if gcd.Sign() != 0 { + d.Quo(amount.Denom(), &gcd) + } else { + d.Set(amount.Denom()) + } + d.Abs(&d) + + var power, result big.Int + one := big.NewInt(1) + ten := big.NewInt(10) + + // Estimate an initial power + if d.IsUint64() { + power.SetInt64(int64(math.Log10(float64(d.Uint64())))) + } else { + + // If the simplified denominator wasn't a uint64, its > 10^19 + power.SetInt64(19) + } + + // If the initial estimate was too high, bring it down + result.Exp(ten, &power, nil) + for result.Cmp(&d) > 0 { + power.Sub(&power, one) + result.Exp(ten, &power, nil) + } + // If it was too low, bring it up + for result.Cmp(&d) < 0 { + power.Add(&power, one) + result.Exp(ten, &power, nil) + } + + if !power.IsUint64() { + panic("Unable to represent Amount's precision as a uint64") + } + return power.Uint64() +} diff --git a/internal/models/amounts_test.go b/internal/models/amounts_test.go new file mode 100644 index 0000000..c7f45ed --- /dev/null +++ b/internal/models/amounts_test.go @@ -0,0 +1,159 @@ +package models_test + +import ( + "github.com/aclindsa/moneygo/internal/models" + "testing" +) + +func expectedPrecision(t *testing.T, amount *models.Amount, precision uint64) { + t.Helper() + if amount.Precision() != precision { + t.Errorf("Expected precision %d for %s, found %d", precision, amount.String(), amount.Precision()) + } +} + +func TestAmountPrecision(t *testing.T) { + var a models.Amount + a.SetString("1.1928712") + expectedPrecision(t, &a, 7) + a.SetString("0") + expectedPrecision(t, &a, 0) + a.SetString("-0.7") + expectedPrecision(t, &a, 1) + a.SetString("-1.1837281037509137509173049173052130957210361309572047598275398265926351231426357130289523647634895285603247284245928712") + expectedPrecision(t, &a, 118) + a.SetInt64(1050) + expectedPrecision(t, &a, 0) +} + +func TestAmountRound(t *testing.T) { + var a models.Amount + tests := []struct { + String string + RoundTo uint64 + Expected string + }{ + {"0", 5, "0"}, + {"929.92928", 2, "929.93"}, + {"-105.499999", 4, "-105.5"}, + {"0.5111111", 1, "0.5"}, + {"0.5111111", 0, "1"}, + {"9.876456", 3, "9.876"}, + } + + for _, test := range tests { + a.SetString(test.String) + a.Round(test.RoundTo) + if a.String() != test.Expected { + t.Errorf("Expected '%s' after Round(%d) to be %s intead of %s\n", test.String, test.RoundTo, test.Expected, a.String()) + } + } +} + +func TestAmountString(t *testing.T) { + var a models.Amount + for _, s := range []string{ + "1.1928712", + "0", + "-0.7", + "-1.1837281037509137509173049173052130957210361309572047598275398265926351231426357130289523647634895285603247284245928712", + "1050", + } { + a.SetString(s) + if s != a.String() { + t.Errorf("Expected '%s', found '%s'", s, a.String()) + } + } + + a.SetString("+182.27") + if "182.27" != a.String() { + t.Errorf("Expected '182.27', found '%s'", a.String()) + } + a.SetString("-0") + if "0" != a.String() { + t.Errorf("Expected '0', found '%s'", a.String()) + } +} + +func TestWhole(t *testing.T) { + var a models.Amount + tests := []struct { + String string + Whole int64 + }{ + {"0", 0}, + {"-0", 0}, + {"181.1293871230", 181}, + {"-0.1821", 0}, + {"99992737.9", 99992737}, + {"-7380.000009", -7380}, + {"4108740192740912741", 4108740192740912741}, + } + + for _, test := range tests { + a.SetString(test.String) + val, err := a.Whole() + if err != nil { + t.Errorf("Unexpected error: %s\n", err) + } else if val != test.Whole { + t.Errorf("Expected '%s'.Whole() to return %d intead of %d\n", test.String, test.Whole, val) + } + } + + a.SetString("81367662642302823790328492349823472634926342") + _, err := a.Whole() + if err == nil { + t.Errorf("Expected error for overflowing int64") + } +} + +func TestFractional(t *testing.T) { + var a models.Amount + tests := []struct { + String string + Precision uint64 + Fractional int64 + }{ + {"0", 5, 0}, + {"181.1293871230", 9, 129387123}, + {"181.1293871230", 10, 1293871230}, + {"181.1293871230", 15, 129387123000000}, + {"1828.37", 7, 3700000}, + {"-0.748", 5, -74800}, + {"-9", 5, 0}, + {"-9.9", 1, -9}, + } + + for _, test := range tests { + a.SetString(test.String) + val, err := a.Fractional(test.Precision) + if err != nil { + t.Errorf("Unexpected error: %s\n", err) + } else if val != test.Fractional { + t.Errorf("Expected '%s'.Fractional(%d) to return %d intead of %d\n", test.String, test.Precision, test.Fractional, val) + } + } +} + +func TestFromParts(t *testing.T) { + var a models.Amount + tests := []struct { + Whole int64 + Fractional int64 + Precision uint64 + Result string + }{ + {839, 9080, 4, "839.908"}, + {-10, 0, 5, "-10"}, + {0, 900, 10, "0.00000009"}, + {9128713621, 87272727, 20, "9128713621.00000000000087272727"}, + {89, 1, 0, "90"}, // Not sure if this should really be supported, but it is + } + + for _, test := range tests { + a.FromParts(test.Whole, test.Fractional, test.Precision) + if a.String() != test.Result { + t.Errorf("Expected Amount.FromParts(%d, %d, %d) to return %s intead of %s\n", test.Whole, test.Fractional, test.Precision, test.Result, a.String()) + } + } +} diff --git a/internal/models/prices.go b/internal/models/prices.go index 7958e52..13594d6 100644 --- a/internal/models/prices.go +++ b/internal/models/prices.go @@ -12,7 +12,7 @@ type Price struct { SecurityId int64 CurrencyId int64 Date time.Time - Value string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString() + Value Amount // price of Security in Currency units RemoteId string // unique ID from source, for detecting duplicates } diff --git a/internal/models/securities.go b/internal/models/securities.go index 67557be..9d7125a 100644 --- a/internal/models/securities.go +++ b/internal/models/securities.go @@ -23,6 +23,9 @@ func GetSecurityType(typestring string) SecurityType { } } +// MaxPrexision denotes the maximum valid value for Security.Precision +const MaxPrecision uint64 = 15 + type Security struct { SecurityId int64 UserId int64 @@ -31,7 +34,7 @@ type Security struct { Symbol string // Number of decimal digits (to the right of the decimal point) this // security is precise to - Precision int `db:"Preciseness"` + Precision uint64 `db:"Preciseness"` Type SecurityType // AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency AlternateId string diff --git a/internal/models/transactions.go b/internal/models/transactions.go index 5046e65..e873a95 100644 --- a/internal/models/transactions.go +++ b/internal/models/transactions.go @@ -2,8 +2,6 @@ package models import ( "encoding/json" - "errors" - "math/big" "net/http" "strings" "time" @@ -49,28 +47,11 @@ type Split struct { RemoteId string // unique ID from server, for detecting duplicates Number string // Check or reference number Memo string - Amount string // String representation of decimal, suitable for passing to big.Rat.SetString() -} - -func GetBigAmount(amt string) (*big.Rat, error) { - var r big.Rat - _, success := r.SetString(amt) - if !success { - return nil, errors.New("Couldn't convert string amount to big.Rat via SetString()") - } - return &r, nil -} - -func (s *Split) GetAmount() (*big.Rat, error) { - return GetBigAmount(s.Amount) + Amount Amount } func (s *Split) Valid() bool { - if (s.AccountId == -1) == (s.SecurityId == -1) { - return false - } - _, err := s.GetAmount() - return err == nil + return (s.AccountId == -1) != (s.SecurityId == -1) } type Transaction struct { @@ -89,8 +70,8 @@ type AccountTransactionsList struct { Account *Account Transactions *[]*Transaction TotalTransactions int64 - BeginningBalance string - EndingBalance string + BeginningBalance Amount + EndingBalance Amount } func (t *Transaction) Write(w http.ResponseWriter) error { diff --git a/internal/reports/accounts.go b/internal/reports/accounts.go index d495c06..faa7778 100644 --- a/internal/reports/accounts.go +++ b/internal/reports/accounts.go @@ -147,18 +147,12 @@ func luaAccount__index(L *lua.LState) int { return 1 } -func balanceFromSplits(splits *[]*models.Split) (*big.Rat, error) { - var balance, tmp big.Rat +func balanceFromSplits(splits *[]*models.Split) *big.Rat { + var balance big.Rat for _, s := range *splits { - rat_amount, err := models.GetBigAmount(s.Amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) + balance.Add(&balance, &s.Amount.Rat) } - - return &balance, nil + return &balance } func luaAccountBalance(L *lua.LState) int { @@ -196,14 +190,12 @@ func luaAccountBalance(L *lua.LState) int { if err != nil { panic("Failed to fetch splits for account:" + err.Error()) } - rat, err := balanceFromSplits(splits) - if err != nil { - panic("Failed to calculate balance for account:" + err.Error()) - } + rat := balanceFromSplits(splits) b := &Balance{ - Amount: rat, + Amount: models.Amount{*rat}, Security: security, } + L.Push(BalanceToLua(L, b)) return 1 diff --git a/internal/reports/balance.go b/internal/reports/balance.go index acd5c43..bc06bbe 100644 --- a/internal/reports/balance.go +++ b/internal/reports/balance.go @@ -3,12 +3,11 @@ package reports import ( "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" - "math/big" ) type Balance struct { Security *models.Security - Amount *big.Rat + Amount models.Amount } const luaBalanceTypeName = "balance" @@ -66,10 +65,8 @@ func luaGetBalanceOperands(L *lua.LState, m int, n int) (*Balance, *Balance) { } else if bm != nil { nn := L.CheckNumber(n) var balance Balance - var rat big.Rat balance.Security = bm.Security - balance.Amount = rat.SetFloat64(float64(nn)) - if balance.Amount == nil { + if balance.Amount.SetFloat64(float64(nn)) == nil { L.ArgError(n, "non-finite float invalid for operand to balance arithemetic") return nil, nil } @@ -77,10 +74,8 @@ func luaGetBalanceOperands(L *lua.LState, m int, n int) (*Balance, *Balance) { } else if bn != nil { nm := L.CheckNumber(m) var balance Balance - var rat big.Rat balance.Security = bn.Security - balance.Amount = rat.SetFloat64(float64(nm)) - if balance.Amount == nil { + if balance.Amount.SetFloat64(float64(nm)) == nil { L.ArgError(m, "non-finite float invalid for operand to balance arithemetic") return nil, nil } @@ -110,7 +105,7 @@ func luaBalance__index(L *lua.LState) int { func luaBalance__tostring(L *lua.LState) int { b := luaCheckBalance(L, 1) - L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision))) + L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(int(b.Security.Precision)))) return 1 } @@ -119,7 +114,7 @@ func luaBalance__eq(L *lua.LState) int { a := luaCheckBalance(L, 1) b := luaCheckBalance(L, 2) - L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0)) + L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(&b.Amount.Rat) == 0)) return 1 } @@ -131,7 +126,7 @@ func luaBalance__lt(L *lua.LState) int { L.ArgError(2, "Can't compare balances with different securities") } - L.Push(lua.LBool(a.Amount.Cmp(b.Amount) < 0)) + L.Push(lua.LBool(a.Amount.Cmp(&b.Amount.Rat) < 0)) return 1 } @@ -143,7 +138,7 @@ func luaBalance__le(L *lua.LState) int { L.ArgError(2, "Can't compare balances with different securities") } - L.Push(lua.LBool(a.Amount.Cmp(b.Amount) <= 0)) + L.Push(lua.LBool(a.Amount.Cmp(&b.Amount.Rat) <= 0)) return 1 } @@ -156,9 +151,8 @@ func luaBalance__add(L *lua.LState) int { } var balance Balance - var rat big.Rat balance.Security = a.Security - balance.Amount = rat.Add(a.Amount, b.Amount) + balance.Amount.Add(&a.Amount.Rat, &b.Amount.Rat) L.Push(BalanceToLua(L, &balance)) return 1 @@ -172,9 +166,8 @@ func luaBalance__sub(L *lua.LState) int { } var balance Balance - var rat big.Rat balance.Security = a.Security - balance.Amount = rat.Sub(a.Amount, b.Amount) + balance.Amount.Sub(&a.Amount.Rat, &b.Amount.Rat) L.Push(BalanceToLua(L, &balance)) return 1 @@ -188,9 +181,8 @@ func luaBalance__mul(L *lua.LState) int { } var balance Balance - var rat big.Rat balance.Security = a.Security - balance.Amount = rat.Mul(a.Amount, b.Amount) + balance.Amount.Mul(&a.Amount.Rat, &b.Amount.Rat) L.Push(BalanceToLua(L, &balance)) return 1 @@ -204,9 +196,8 @@ func luaBalance__div(L *lua.LState) int { } var balance Balance - var rat big.Rat balance.Security = a.Security - balance.Amount = rat.Quo(a.Amount, b.Amount) + balance.Amount.Quo(&a.Amount.Rat, &b.Amount.Rat) L.Push(BalanceToLua(L, &balance)) return 1 @@ -216,9 +207,8 @@ func luaBalance__unm(L *lua.LState) int { b := luaCheckBalance(L, 1) var balance Balance - var rat big.Rat balance.Security = b.Security - balance.Amount = rat.Neg(b.Amount) + balance.Amount.Neg(&b.Amount.Rat) L.Push(BalanceToLua(L, &balance)) return 1 diff --git a/internal/reports/prices.go b/internal/reports/prices.go index 862448a..6c80e4c 100644 --- a/internal/reports/prices.go +++ b/internal/reports/prices.go @@ -60,11 +60,7 @@ func luaPrice__index(L *lua.LState) int { } L.Push(SecurityToLua(L, c)) case "Value", "value": - amt, err := models.GetBigAmount(p.Value) - if err != nil { - panic(err) - } - float, _ := amt.Float64() + float, _ := p.Value.Float64() L.Push(lua.LNumber(float)) default: L.ArgError(2, "unexpected price attribute: "+field) @@ -86,7 +82,7 @@ func luaPrice__tostring(L *lua.LState) int { panic("Price's currency or security not found for user") } - L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")")) + L.Push(lua.LString(p.Value.String() + " " + c.Symbol + " (" + s.Symbol + ")")) return 1 } diff --git a/internal/store/db/db.go b/internal/store/db/db.go index 5daa9aa..76669c5 100644 --- a/internal/store/db/db.go +++ b/internal/store/db/db.go @@ -40,10 +40,10 @@ func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") - dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") + dbmap.AddTableWithName(Price{}, "prices").SetKeys(true, "PriceId") dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") - dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") + dbmap.AddTableWithName(Split{}, "splits").SetKeys(true, "SplitId") rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) diff --git a/internal/store/db/prices.go b/internal/store/db/prices.go index 2df2ab9..e44f32f 100644 --- a/internal/store/db/prices.go +++ b/internal/store/db/prices.go @@ -6,73 +6,150 @@ import ( "time" ) +// Price is a mirror of models.Price with the Value broken out into whole and +// fractional components +type Price struct { + PriceId int64 + SecurityId int64 + CurrencyId int64 + Date time.Time + WholeValue int64 + FractionalValue int64 + RemoteId string // unique ID from source, for detecting duplicates +} + +func NewPrice(p *models.Price) (*Price, error) { + whole, err := p.Value.Whole() + if err != nil { + return nil, err + } + fractional, err := p.Value.Fractional(MaxPrecision) + if err != nil { + return nil, err + } + return &Price{ + PriceId: p.PriceId, + SecurityId: p.SecurityId, + CurrencyId: p.CurrencyId, + Date: p.Date, + WholeValue: whole, + FractionalValue: fractional, + RemoteId: p.RemoteId, + }, nil +} + +func (p Price) Price() *models.Price { + price := &models.Price{ + PriceId: p.PriceId, + SecurityId: p.SecurityId, + CurrencyId: p.CurrencyId, + Date: p.Date, + RemoteId: p.RemoteId, + } + price.Value.FromParts(p.WholeValue, p.FractionalValue, MaxPrecision) + + return price +} + func (tx *Tx) PriceExists(price *models.Price) (bool, error) { - var prices []*models.Price - _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) + p, err := NewPrice(price) + if err != nil { + return false, err + } + + var prices []*Price + _, err = tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND WholeValue=? AND FractionalValue=?", p.SecurityId, p.CurrencyId, p.Date, p.WholeValue, p.FractionalValue) return len(prices) > 0, err } func (tx *Tx) InsertPrice(price *models.Price) error { - return tx.Insert(price) + p, err := NewPrice(price) + if err != nil { + return err + } + err = tx.Insert(p) + if err != nil { + return err + } + *price = *p.Price() + return nil } func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) { - var price models.Price + var price Price err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) if err != nil { return nil, err } - return &price, nil + return price.Price(), nil } func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) { - var prices []*models.Price + var prices []*Price + var modelprices []*models.Price _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) if err != nil { return nil, err } - return &prices, nil + + for _, p := range prices { + modelprices = append(modelprices, p.Price()) + } + + return &modelprices, nil } // Return the latest price for security in currency units before date func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { - var price models.Price + var price Price err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err } - return &price, nil + return price.Price(), nil } // Return the earliest price for security in currency units after date func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { - var price models.Price + var price Price err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err } - return &price, nil + return price.Price(), nil } func (tx *Tx) UpdatePrice(price *models.Price) error { - count, err := tx.Update(price) + p, err := NewPrice(price) + if err != nil { + return err + } + + count, err := tx.Update(p) if err != nil { return err } if count != 1 { return fmt.Errorf("Expected to update 1 price, was going to update %d", count) } + *price = *p.Price() return nil } func (tx *Tx) DeletePrice(price *models.Price) error { - count, err := tx.Delete(price) + p, err := NewPrice(price) + if err != nil { + return err + } + + count, err := tx.Delete(p) if err != nil { return err } if count != 1 { return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count) } + *price = *p.Price() return nil } diff --git a/internal/store/db/securities.go b/internal/store/db/securities.go index 83a5659..6db9b28 100644 --- a/internal/store/db/securities.go +++ b/internal/store/db/securities.go @@ -6,6 +6,17 @@ import ( "github.com/aclindsa/moneygo/internal/store" ) +// MaxPrexision denotes the maximum valid value for models.Security.Precision. +// This constant is used when storing amounts in securities into the database, +// so it must not be changed without appropriately migrating the database. +const MaxPrecision uint64 = 15 + +func init() { + if MaxPrecision < models.MaxPrecision { + panic("db.MaxPrecision must be >= models.MaxPrecision") + } +} + func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { var s models.Security diff --git a/internal/store/db/transactions.go b/internal/store/db/transactions.go index 29168df..1bcf989 100644 --- a/internal/store/db/transactions.go +++ b/internal/store/db/transactions.go @@ -9,6 +9,71 @@ import ( "time" ) +// Split is a mirror of models.Split with the Amount broken out into whole and +// fractional components +type Split struct { + SplitId int64 + TransactionId int64 + Status int64 + ImportSplitType int64 + + // One of AccountId and SecurityId must be -1 + // In normal splits, AccountId will be valid and SecurityId will be -1. The + // only case where this is reversed is for transactions that have been + // imported and not yet associated with an account. + AccountId int64 + SecurityId int64 + + RemoteId string // unique ID from server, for detecting duplicates + Number string // Check or reference number + Memo string + + // Amount.Whole and Amount.Fractional(MaxPrecision) + WholeAmount int64 + FractionalAmount int64 +} + +func NewSplit(s *models.Split) (*Split, error) { + whole, err := s.Amount.Whole() + if err != nil { + return nil, err + } + fractional, err := s.Amount.Fractional(MaxPrecision) + if err != nil { + return nil, err + } + return &Split{ + SplitId: s.SplitId, + TransactionId: s.TransactionId, + Status: s.Status, + ImportSplitType: s.ImportSplitType, + AccountId: s.AccountId, + SecurityId: s.SecurityId, + RemoteId: s.RemoteId, + Number: s.Number, + Memo: s.Memo, + WholeAmount: whole, + FractionalAmount: fractional, + }, nil +} + +func (s Split) Split() *models.Split { + split := &models.Split{ + SplitId: s.SplitId, + TransactionId: s.TransactionId, + Status: s.Status, + ImportSplitType: s.ImportSplitType, + AccountId: s.AccountId, + SecurityId: s.SecurityId, + RemoteId: s.RemoteId, + Number: s.Number, + Memo: s.Memo, + } + split.Amount.FromParts(s.WholeAmount, s.FractionalAmount, MaxPrecision) + + return split +} + func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error { for i := range accountids { account, err := tx.GetAccount(accountids[i], user.UserId) @@ -68,10 +133,15 @@ func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error for i := range t.Splits { t.Splits[i].TransactionId = t.TransactionId t.Splits[i].SplitId = -1 - err = tx.Insert(t.Splits[i]) + s, err := NewSplit(t.Splits[i]) if err != nil { return err } + err = tx.Insert(s) + if err != nil { + return err + } + *t.Splits[i] = *s.Split() } return nil @@ -84,17 +154,22 @@ func (tx *Tx) SplitExists(s *models.Split) (bool, error) { func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) { var t models.Transaction + var splits []*Split err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) if err != nil { return nil, err } - _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) + _, err = tx.Select(&splits, "SELECT * from splits where TransactionId=?", transactionid) if err != nil { return nil, err } + for _, split := range splits { + t.Splits = append(t.Splits, split.Split()) + } + return &t, nil } @@ -107,17 +182,21 @@ func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) { } for i := range transactions { - _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) + var splits []*Split + _, err := tx.Select(&splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) if err != nil { return nil, err } + for _, split := range splits { + transactions[i].Splits = append(transactions[i].Splits, split.Split()) + } } return &transactions, nil } func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error { - var existing_splits []*models.Split + var existing_splits []*Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) if err != nil { @@ -136,25 +215,30 @@ func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error // Insert splits, updating any pre-existing ones for i := range t.Splits { t.Splits[i].TransactionId = t.TransactionId - _, ok := s_map[t.Splits[i].SplitId] + s, err := NewSplit(t.Splits[i]) + if err != nil { + return err + } + _, ok := s_map[s.SplitId] if ok { - count, err := tx.Update(t.Splits[i]) + count, err := tx.Update(s) if err != nil { return err } if count > 1 { return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) } - delete(s_map, t.Splits[i].SplitId) + delete(s_map, s.SplitId) } else { - t.Splits[i].SplitId = -1 - err := tx.Insert(t.Splits[i]) + s.SplitId = -1 + err := tx.Insert(s) if err != nil { return err } } + *t.Splits[i] = *s.Split() if t.Splits[i].AccountId != -1 { - a_map[t.Splits[i].AccountId] = true + a_map[s.AccountId] = true } } @@ -222,57 +306,69 @@ func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error } func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) { - var splits []*models.Split + var modelsplits []*models.Split + var splits []*Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" _, err := tx.Select(&splits, sql, accountid, user.UserId) if err != nil { return nil, err } - return &splits, nil + + for _, s := range splits { + modelsplits = append(modelsplits, s.Split()) + } + return &modelsplits, nil } // Assumes accountid is valid and is owned by the current user func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) { - var splits []*models.Split + var modelsplits []*models.Split + var splits []*Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" _, err := tx.Select(&splits, sql, accountid, user.UserId, date) if err != nil { return nil, err } - return &splits, err + + for _, s := range splits { + modelsplits = append(modelsplits, s.Split()) + } + return &modelsplits, nil } func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) { - var splits []*models.Split + var modelsplits []*models.Split + var splits []*Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) if err != nil { return nil, err } - return &splits, nil + + for _, s := range splits { + modelsplits = append(modelsplits, s.Split()) + } + return &modelsplits, nil } func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) { - var pageDifference, tmp big.Rat + var pageDifference big.Rat for i := range transactions { - _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) + var splits []*Split + _, err := tx.Select(&splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) if err != nil { return nil, err } // Sum up the amounts from the splits we're returning so we can return // an ending balance - for j := range transactions[i].Splits { + for j, s := range splits { + transactions[i].Splits = append(transactions[i].Splits, s.Split()) if transactions[i].Splits[j].AccountId == accountid { - rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) - if err != nil { - return nil, err - } - tmp.Add(&pageDifference, rat_amount) - pageDifference.Set(&tmp) + pageDifference.Add(&pageDifference, &transactions[i].Splits[j].Amount.Rat) } } } @@ -338,24 +434,31 @@ func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort st // Sum all the splits for all transaction splits for this account that // occurred before the page we're returning - var amounts []string - sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" - _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) + sql = "FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" + count, err = tx.SelectInt("SELECT count(*) "+sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) if err != nil { return nil, err } - var tmp, balance big.Rat - for _, amount := range amounts { - rat_amount, err := models.GetBigAmount(amount) + var balance models.Amount + + // Don't attempt to 'sum()' the splits if none exist, because it is + // supposed to return null/nil in this case, which makes gorp angry since + // we're using SelectInt() + if count > 0 { + whole, err := tx.SelectInt("SELECT sum(s.WholeAmount) "+sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) if err != nil { return nil, err } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) + fractional, err := tx.SelectInt("SELECT sum(s.FractionalAmount) "+sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) + if err != nil { + return nil, err + } + balance.FromParts(whole, fractional, MaxPrecision) } - atl.BeginningBalance = balance.FloatString(security.Precision) - atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) + + atl.BeginningBalance = balance + atl.EndingBalance.Rat.Add(&balance.Rat, pageDifference) return &atl, nil }