mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-10-29 23:40:04 -04:00
commit
d9ddef250a
48
accounts.go
48
accounts.go
@ -124,10 +124,10 @@ func (al *AccountList) Write(w http.ResponseWriter) error {
|
||||
return enc.Encode(al)
|
||||
}
|
||||
|
||||
func GetAccount(accountid int64, userid int64) (*Account, error) {
|
||||
func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) {
|
||||
var a Account
|
||||
|
||||
err := DB.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||
err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -145,10 +145,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64)
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func GetAccounts(userid int64) (*[]Account, error) {
|
||||
func GetAccounts(db *DB, userid int64) (*[]Account, error) {
|
||||
var accounts []Account
|
||||
|
||||
_, err := DB.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||
_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -276,8 +276,8 @@ func (pame ParentAccountMissingError) Error() string {
|
||||
return "Parent account missing"
|
||||
}
|
||||
|
||||
func insertUpdateAccount(a *Account, insert bool) error {
|
||||
transaction, err := DB.Begin()
|
||||
func insertUpdateAccount(db *DB, a *Account, insert bool) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -329,16 +329,16 @@ func insertUpdateAccount(a *Account, insert bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertAccount(a *Account) error {
|
||||
return insertUpdateAccount(a, true)
|
||||
func InsertAccount(db *DB, a *Account) error {
|
||||
return insertUpdateAccount(db, a, true)
|
||||
}
|
||||
|
||||
func UpdateAccount(a *Account) error {
|
||||
return insertUpdateAccount(a, false)
|
||||
func UpdateAccount(db *DB, a *Account) error {
|
||||
return insertUpdateAccount(db, a, false)
|
||||
}
|
||||
|
||||
func DeleteAccount(a *Account) error {
|
||||
transaction, err := DB.Begin()
|
||||
func DeleteAccount(db *DB, a *Account) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -385,8 +385,8 @@ func DeleteAccount(a *Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -405,7 +405,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
AccountImportHandler(w, r, user, accountid, importtype)
|
||||
AccountImportHandler(db, w, r, user, accountid, importtype)
|
||||
return
|
||||
}
|
||||
|
||||
@ -425,7 +425,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account.UserId = user.UserId
|
||||
account.AccountVersion = 0
|
||||
|
||||
security, err := GetSecurity(account.SecurityId, user.UserId)
|
||||
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -436,7 +436,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = InsertAccount(&account)
|
||||
err = InsertAccount(db, &account)
|
||||
if err != nil {
|
||||
if _, ok := err.(ParentAccountMissingError); ok {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
@ -461,7 +461,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil || n != 1 {
|
||||
//Return all Accounts
|
||||
var al AccountList
|
||||
accounts, err := GetAccounts(user.UserId)
|
||||
accounts, err := GetAccounts(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -478,12 +478,12 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// if URL looks like /account/[0-9]+/transactions, use the account
|
||||
// transaction handler
|
||||
if accountTransactionsRE.MatchString(r.URL.Path) {
|
||||
AccountTransactionsHandler(w, r, user, accountid)
|
||||
AccountTransactionsHandler(db, w, r, user, accountid)
|
||||
return
|
||||
}
|
||||
|
||||
// Return Account with this Id
|
||||
account, err := GetAccount(accountid, user.UserId)
|
||||
account, err := GetAccount(db, accountid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
@ -517,7 +517,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
account.UserId = user.UserId
|
||||
|
||||
security, err := GetSecurity(account.SecurityId, user.UserId)
|
||||
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -528,7 +528,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = UpdateAccount(&account)
|
||||
err = UpdateAccount(db, &account)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -542,13 +542,13 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
account, err := GetAccount(accountid, user.UserId)
|
||||
account, err := GetAccount(db, accountid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteAccount(account)
|
||||
err = DeleteAccount(db, account)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
@ -15,14 +15,19 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find DB in lua's Context")
|
||||
}
|
||||
|
||||
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||
if !ok {
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
accounts, err := GetAccounts(user.UserId)
|
||||
accounts, err := GetAccounts(db, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -144,6 +149,10 @@ func luaAccountBalance(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
|
||||
ctx := L.Context()
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
panic("Couldn't find DB in lua's Context")
|
||||
}
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
panic("Couldn't find User in lua's Context")
|
||||
@ -162,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int {
|
||||
if date != nil {
|
||||
end := luaWeakCheckTime(L, 3)
|
||||
if end != nil {
|
||||
rat, err = GetAccountBalanceDateRange(user, a.AccountId, date, end)
|
||||
rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end)
|
||||
} else {
|
||||
rat, err = GetAccountBalanceDate(user, a.AccountId, date)
|
||||
rat, err = GetAccountBalanceDate(db, user, a.AccountId, date)
|
||||
}
|
||||
} else {
|
||||
rat, err = GetAccountBalance(user, a.AccountId)
|
||||
rat, err = GetAccountBalance(db, user, a.AccountId)
|
||||
}
|
||||
if err != nil {
|
||||
panic("Failed to GetAccountBalance:" + err.Error())
|
||||
|
28
db.go
28
db.go
@ -2,22 +2,34 @@ package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gopkg.in/gorp.v1"
|
||||
"log"
|
||||
)
|
||||
|
||||
var DB *gorp.DbMap
|
||||
|
||||
func initDB(cfg *Config) {
|
||||
func initDB(cfg *Config) (*gorp.DbMap, error) {
|
||||
db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbmap := &gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
|
||||
var dialect gorp.Dialect
|
||||
if cfg.MoneyGo.DBType == SQLite {
|
||||
dialect = gorp.SqliteDialect{}
|
||||
} else if cfg.MoneyGo.DBType == MySQL {
|
||||
dialect = gorp.MySQLDialect{
|
||||
Engine: "InnoDB",
|
||||
Encoding: "UTF8",
|
||||
}
|
||||
} else if cfg.MoneyGo.DBType == Postgres {
|
||||
dialect = gorp.PostgresDialect{}
|
||||
} else {
|
||||
return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String())
|
||||
}
|
||||
|
||||
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
||||
dbmap.AddTableWithName(User{}, "users").SetKeys(true, "UserId")
|
||||
dbmap.AddTableWithName(Session{}, "sessions").SetKeys(true, "SessionId")
|
||||
dbmap.AddTableWithName(Account{}, "accounts").SetKeys(true, "AccountId")
|
||||
@ -29,8 +41,8 @@ func initDB(cfg *Config) {
|
||||
|
||||
err = dbmap.CreateTablesIfNotExists()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
DB = dbmap
|
||||
return dbmap, nil
|
||||
}
|
||||
|
@ -308,8 +308,8 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
|
||||
return &gncimport, nil
|
||||
}
|
||||
|
||||
func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -365,7 +365,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sqltransaction, err := DB.Begin()
|
||||
sqltransaction, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
20
imports.go
20
imports.go
@ -22,7 +22,7 @@ func (od *OFXDownload) Read(json_str string) error {
|
||||
return dec.Decode(od)
|
||||
}
|
||||
|
||||
func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
|
||||
func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
|
||||
itl, err := ImportOFX(r)
|
||||
|
||||
if err != nil {
|
||||
@ -38,7 +38,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
|
||||
return
|
||||
}
|
||||
|
||||
sqltransaction, err := DB.Begin()
|
||||
sqltransaction, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -258,7 +258,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i
|
||||
WriteSuccess(w)
|
||||
}
|
||||
|
||||
func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||
func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||
download_json := r.PostFormValue("ofxdownload")
|
||||
if download_json == "" {
|
||||
log.Print("download_json")
|
||||
@ -274,7 +274,7 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
|
||||
return
|
||||
}
|
||||
|
||||
account, err := GetAccount(accountid, user.UserId)
|
||||
account, err := GetAccount(db, accountid, user.UserId)
|
||||
if err != nil {
|
||||
log.Print("GetAccount")
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
@ -367,10 +367,10 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
ofxImportHelper(response.Body, w, user, accountid)
|
||||
ofxImportHelper(db, response.Body, w, user, accountid)
|
||||
}
|
||||
|
||||
func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||
func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
||||
multipartReader, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
@ -390,19 +390,19 @@ func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
|
||||
return
|
||||
}
|
||||
|
||||
ofxImportHelper(part, w, user, accountid)
|
||||
ofxImportHelper(db, part, w, user, accountid)
|
||||
}
|
||||
|
||||
/*
|
||||
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
||||
*/
|
||||
func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
|
||||
func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
|
||||
|
||||
switch importtype {
|
||||
case "ofx":
|
||||
OFXImportHandler(w, r, user, accountid)
|
||||
OFXImportHandler(db, w, r, user, accountid)
|
||||
case "ofxfile":
|
||||
OFXFileImportHandler(w, r, user, accountid)
|
||||
OFXFileImportHandler(db, w, r, user, accountid)
|
||||
default:
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
}
|
||||
|
44
main.go
44
main.go
@ -4,6 +4,7 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"gopkg.in/gorp.v1"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -43,8 +44,6 @@ func init() {
|
||||
|
||||
// Setup the logging flags to be printed
|
||||
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
|
||||
|
||||
initDB(config)
|
||||
}
|
||||
|
||||
func rootHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@ -55,18 +54,39 @@ func staticHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path))
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Create a closure over db, allowing the handlers to look like a
|
||||
// http.HandlerFunc
|
||||
type DB = gorp.DbMap
|
||||
type DBHandler func(http.ResponseWriter, *http.Request, *DB)
|
||||
|
||||
func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
h(w, r, db)
|
||||
}
|
||||
}
|
||||
|
||||
func GetHandler(db *DB) http.Handler {
|
||||
servemux := http.NewServeMux()
|
||||
servemux.HandleFunc("/", rootHandler)
|
||||
servemux.HandleFunc("/static/", staticHandler)
|
||||
servemux.HandleFunc("/session/", SessionHandler)
|
||||
servemux.HandleFunc("/user/", UserHandler)
|
||||
servemux.HandleFunc("/security/", SecurityHandler)
|
||||
servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db))
|
||||
servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db))
|
||||
servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db))
|
||||
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
|
||||
servemux.HandleFunc("/account/", AccountHandler)
|
||||
servemux.HandleFunc("/transaction/", TransactionHandler)
|
||||
servemux.HandleFunc("/import/gnucash", GnucashImportHandler)
|
||||
servemux.HandleFunc("/report/", ReportHandler)
|
||||
servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db))
|
||||
servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db))
|
||||
servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db))
|
||||
servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db))
|
||||
|
||||
return servemux
|
||||
}
|
||||
|
||||
func main() {
|
||||
database, err := initDB(config)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
handler := GetHandler(database)
|
||||
|
||||
listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port))
|
||||
if err != nil {
|
||||
@ -75,8 +95,8 @@ func main() {
|
||||
|
||||
log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir)
|
||||
if config.MoneyGo.Fcgi {
|
||||
fcgi.Serve(listener, servemux)
|
||||
fcgi.Serve(listener, handler)
|
||||
} else {
|
||||
http.Serve(listener, servemux)
|
||||
http.Serve(listener, handler)
|
||||
}
|
||||
}
|
||||
|
13
prices.go
13
prices.go
@ -1,8 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/FlashBoys/go-finance"
|
||||
"gopkg.in/gorp.v1"
|
||||
"time"
|
||||
)
|
||||
@ -93,8 +91,8 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi
|
||||
}
|
||||
}
|
||||
|
||||
func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, error) {
|
||||
transaction, err := DB.Begin()
|
||||
func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -113,10 +111,3 @@ func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, err
|
||||
|
||||
return price, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
q, err := finance.GetQuote("BRK-A")
|
||||
if err == nil {
|
||||
fmt.Printf("%+v", q)
|
||||
}
|
||||
}
|
||||
|
48
reports.go
48
reports.go
@ -27,6 +27,7 @@ const (
|
||||
accountsContextKey
|
||||
securitiesContextKey
|
||||
balanceContextKey
|
||||
dbContextKey
|
||||
)
|
||||
|
||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
||||
@ -76,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error {
|
||||
return enc.Encode(r)
|
||||
}
|
||||
|
||||
func GetReport(reportid int64, userid int64) (*Report, error) {
|
||||
func GetReport(db *DB, reportid int64, userid int64) (*Report, error) {
|
||||
var r Report
|
||||
|
||||
err := DB.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||
err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func GetReports(userid int64) (*[]Report, error) {
|
||||
func GetReports(db *DB, userid int64) (*[]Report, error) {
|
||||
var reports []Report
|
||||
|
||||
_, err := DB.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||
_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &reports, nil
|
||||
}
|
||||
|
||||
func InsertReport(r *Report) error {
|
||||
err := DB.Insert(r)
|
||||
func InsertReport(db *DB, r *Report) error {
|
||||
err := db.Insert(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateReport(r *Report) error {
|
||||
count, err := DB.Update(r)
|
||||
func UpdateReport(db *DB, r *Report) error {
|
||||
count, err := db.Update(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -115,8 +116,8 @@ func UpdateReport(r *Report) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteReport(r *Report) error {
|
||||
count, err := DB.Delete(r)
|
||||
func DeleteReport(db *DB, r *Report) error {
|
||||
count, err := db.Delete(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -126,13 +127,14 @@ func DeleteReport(r *Report) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runReport(user *User, report *Report) (*Tabulation, error) {
|
||||
func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
|
||||
// Create a new LState without opening the default libs for security
|
||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||
defer L.Close()
|
||||
|
||||
// Create a new context holding the current user with a timeout
|
||||
ctx := context.WithValue(context.Background(), userContextKey, user)
|
||||
ctx = context.WithValue(ctx, dbContextKey, db)
|
||||
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
|
||||
defer cancel()
|
||||
L.SetContext(ctx)
|
||||
@ -189,14 +191,14 @@ func runReport(user *User, report *Report) (*Tabulation, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
|
||||
report, err := GetReport(reportid, user.UserId)
|
||||
func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
|
||||
report, err := GetReport(db, reportid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
tabulation, err := runReport(user, report)
|
||||
tabulation, err := runReport(db, user, report)
|
||||
if err != nil {
|
||||
// TODO handle different failure cases differently
|
||||
log.Print("runReport returned:", err)
|
||||
@ -214,8 +216,8 @@ func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User,
|
||||
}
|
||||
}
|
||||
|
||||
func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -237,7 +239,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
report.ReportId = -1
|
||||
report.UserId = user.UserId
|
||||
|
||||
err = InsertReport(&report)
|
||||
err = InsertReport(db, &report)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -260,7 +262,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
ReportTabulationHandler(w, r, user, reportid)
|
||||
ReportTabulationHandler(db, w, r, user, reportid)
|
||||
return
|
||||
}
|
||||
|
||||
@ -269,7 +271,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil || n != 1 {
|
||||
//Return all Reports
|
||||
var rl ReportList
|
||||
reports, err := GetReports(user.UserId)
|
||||
reports, err := GetReports(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -284,7 +286,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
// Return Report with this Id
|
||||
report, err := GetReport(reportid, user.UserId)
|
||||
report, err := GetReport(db, reportid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
@ -319,7 +321,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
report.UserId = user.UserId
|
||||
|
||||
err = UpdateReport(&report)
|
||||
err = UpdateReport(db, &report)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -333,13 +335,13 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
report, err := GetReport(reportid, user.UserId)
|
||||
report, err := GetReport(db, reportid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteReport(report)
|
||||
err = DeleteReport(db, report)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
@ -96,10 +96,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSecurity(securityid int64, userid int64) (*Security, error) {
|
||||
func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
|
||||
var s Security
|
||||
|
||||
err := DB.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||
err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -116,18 +116,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func GetSecurities(userid int64) (*[]*Security, error) {
|
||||
func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
|
||||
var securities []*Security
|
||||
|
||||
_, err := DB.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||
_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &securities, nil
|
||||
}
|
||||
|
||||
func InsertSecurity(s *Security) error {
|
||||
err := DB.Insert(s)
|
||||
func InsertSecurity(db *DB, s *Security) error {
|
||||
err := db.Insert(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -142,8 +142,8 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateSecurity(s *Security) error {
|
||||
transaction, err := DB.Begin()
|
||||
func UpdateSecurity(db *DB, s *Security) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -176,8 +176,8 @@ func UpdateSecurity(s *Security) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteSecurity(s *Security) error {
|
||||
transaction, err := DB.Begin()
|
||||
func DeleteSecurity(db *DB, s *Security) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -279,8 +279,8 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
|
||||
return security, nil
|
||||
}
|
||||
|
||||
func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -302,7 +302,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
||||
security.SecurityId = -1
|
||||
security.UserId = user.UserId
|
||||
|
||||
err = InsertSecurity(&security)
|
||||
err = InsertSecurity(db, &security)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -324,7 +324,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
||||
//Return all securities
|
||||
var sl SecurityList
|
||||
|
||||
securities, err := GetSecurities(user.UserId)
|
||||
securities, err := GetSecurities(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -339,7 +339,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
security, err := GetSecurity(securityid, user.UserId)
|
||||
security, err := GetSecurity(db, securityid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
@ -373,7 +373,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
security.UserId = user.UserId
|
||||
|
||||
err = UpdateSecurity(&security)
|
||||
err = UpdateSecurity(db, &security)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -387,13 +387,13 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
security, err := GetSecurity(securityid, user.UserId)
|
||||
security, err := GetSecurity(db, securityid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteSecurity(security)
|
||||
err = DeleteSecurity(db, security)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
@ -13,14 +13,19 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find DB in lua's Context")
|
||||
}
|
||||
|
||||
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||
if !ok {
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
securities, err := GetSecurities(user.UserId)
|
||||
securities, err := GetSecurities(db, user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -149,7 +154,13 @@ func luaClosestPrice(L *lua.LState) int {
|
||||
c := luaCheckSecurity(L, 2)
|
||||
date := luaCheckTime(L, 3)
|
||||
|
||||
p, err := GetClosestPrice(s, c, date)
|
||||
ctx := L.Context()
|
||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
||||
if !ok {
|
||||
panic("Couldn't find DB in lua's Context")
|
||||
}
|
||||
|
||||
p, err := GetClosestPrice(db, s, c, date)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
} else {
|
||||
|
27
sessions.go
27
sessions.go
@ -22,7 +22,7 @@ func (s *Session) Write(w http.ResponseWriter) error {
|
||||
return enc.Encode(s)
|
||||
}
|
||||
|
||||
func GetSession(r *http.Request) (*Session, error) {
|
||||
func GetSession(db *DB, r *http.Request) (*Session, error) {
|
||||
var s Session
|
||||
|
||||
cookie, err := r.Cookie("moneygo-session")
|
||||
@ -31,17 +31,18 @@ func GetSession(r *http.Request) (*Session, error) {
|
||||
}
|
||||
s.SessionSecret = cookie.Value
|
||||
|
||||
err = DB.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
||||
err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func DeleteSessionIfExists(r *http.Request) {
|
||||
session, err := GetSession(r)
|
||||
func DeleteSessionIfExists(db *DB, r *http.Request) {
|
||||
// TODO do this in one transaction
|
||||
session, err := GetSession(db, r)
|
||||
if err == nil {
|
||||
DB.Delete(session)
|
||||
db.Delete(session)
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,7 +54,7 @@ func NewSessionCookie() (string, error) {
|
||||
return base64.StdEncoding.EncodeToString(bits), nil
|
||||
}
|
||||
|
||||
func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
|
||||
func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
|
||||
s := Session{}
|
||||
|
||||
session_secret, err := NewSessionCookie()
|
||||
@ -75,14 +76,14 @@ func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session,
|
||||
s.SessionSecret = session_secret
|
||||
s.UserId = userid
|
||||
|
||||
err = DB.Insert(&s)
|
||||
err = db.Insert(&s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
if r.Method == "POST" || r.Method == "PUT" {
|
||||
user_json := r.PostFormValue("user")
|
||||
if user_json == "" {
|
||||
@ -97,7 +98,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
dbuser, err := GetUserByUsername(user.Username)
|
||||
dbuser, err := GetUserByUsername(db, user.Username)
|
||||
if err != nil {
|
||||
WriteError(w, 2 /*Unauthorized Access*/)
|
||||
return
|
||||
@ -109,9 +110,9 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
DeleteSessionIfExists(r)
|
||||
DeleteSessionIfExists(db, r)
|
||||
|
||||
session, err := NewSession(w, r, dbuser.UserId)
|
||||
session, err := NewSession(db, w, r, dbuser.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
return
|
||||
@ -124,7 +125,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else if r.Method == "GET" {
|
||||
s, err := GetSession(r)
|
||||
s, err := GetSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -132,7 +133,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
s.Write(w)
|
||||
} else if r.Method == "DELETE" {
|
||||
DeleteSessionIfExists(r)
|
||||
DeleteSessionIfExists(db, r)
|
||||
WriteSuccess(w)
|
||||
}
|
||||
}
|
||||
|
137
transactions.go
137
transactions.go
@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
||||
if t.Splits[i].AccountId != -1 {
|
||||
var err error
|
||||
var account *Account
|
||||
if transaction != nil {
|
||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
||||
} else {
|
||||
account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
|
||||
}
|
||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
||||
return sums, nil
|
||||
}
|
||||
|
||||
func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
|
||||
return t.GetImbalancesTx(nil)
|
||||
}
|
||||
|
||||
// Returns true if all securities contained in this transaction are balanced,
|
||||
// false otherwise
|
||||
func (t *Transaction) Balanced() (bool, error) {
|
||||
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
|
||||
var zero big.Rat
|
||||
|
||||
sums, err := t.GetImbalances()
|
||||
sums, err := t.GetImbalancesTx(transaction)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
|
||||
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
|
||||
var t Transaction
|
||||
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) {
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func GetTransactions(userid int64) (*[]Transaction, error) {
|
||||
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
|
||||
var transactions []Transaction
|
||||
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertTransaction(t *Transaction, user *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
func InsertTransaction(db *DB, t *Transaction, user *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateTransaction(t *Transaction, user *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
||||
var existing_splits []*Split
|
||||
|
||||
_, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
if ok {
|
||||
count, err := transaction.Update(t.Splits[i])
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Updated more than one transaction split")
|
||||
}
|
||||
delete(s_map, t.Splits[i].SplitId)
|
||||
@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
t.Splits[i].SplitId = -1
|
||||
err := transaction.Insert(t.Splits[i])
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
if ok {
|
||||
_, err := transaction.Delete(existing_splits[i])
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error {
|
||||
}
|
||||
err = incrementAccountVersions(transaction, user, a_ids)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.Update(t)
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
transaction.Rollback()
|
||||
return errors.New("Updated more than one transaction")
|
||||
}
|
||||
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
transaction.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteTransaction(t *Transaction, user *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
transaction.TransactionId = -1
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
balanced, err := transaction.Balanced()
|
||||
sqltx, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
balanced, err := transaction.Balanced(sqltx)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if !transaction.Valid() || !balanced {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
transaction.Splits[i].SplitId = -1
|
||||
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = InsertTransaction(&transaction, user)
|
||||
err = InsertTransactionTx(sqltx, &transaction, user)
|
||||
if err != nil {
|
||||
if _, ok := err.(AccountMissingError); ok {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
}
|
||||
sqltx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
err = sqltx.Commit()
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
//Return all Transactions
|
||||
var al TransactionList
|
||||
transactions, err := GetTransactions(user.UserId)
|
||||
transactions, err := GetTransactions(db, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
//Return Transaction with this Id
|
||||
transaction, err := GetTransaction(transactionid, user.UserId)
|
||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
transaction.UserId = user.UserId
|
||||
|
||||
balanced, err := transaction.Balanced()
|
||||
sqltx, err := db.Begin()
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
balanced, err := transaction.Balanced(sqltx)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
if !transaction.Valid() || !balanced {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transaction.Splits {
|
||||
_, err := GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = UpdateTransaction(&transaction, user)
|
||||
err = UpdateTransactionTx(sqltx, &transaction, user)
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sqltx.Commit()
|
||||
if err != nil {
|
||||
sqltx.Rollback()
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
return
|
||||
@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
transaction, err := GetTransaction(transactionid, user.UserId)
|
||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
||||
if err != nil {
|
||||
WriteError(w, 3 /*Invalid Request*/)
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteTransaction(transaction, user)
|
||||
err = DeleteTransaction(db, transaction, user)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
|
||||
return &pageDifference, nil
|
||||
}
|
||||
|
||||
func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
|
||||
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) {
|
||||
}
|
||||
|
||||
// Assumes accountid is valid and is owned by the current user
|
||||
func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||
var splits []Split
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti
|
||||
return &balance, nil
|
||||
}
|
||||
|
||||
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||
var transactions []Transaction
|
||||
var atl AccountTransactionsList
|
||||
|
||||
transaction, err := DB.Begin()
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6
|
||||
|
||||
// Return only those transactions which have at least one split pertaining to
|
||||
// an account
|
||||
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
|
||||
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
||||
user *User, accountid int64) {
|
||||
|
||||
var page uint64 = 0
|
||||
@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
|
||||
sort = sortstring
|
||||
}
|
||||
|
||||
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit)
|
||||
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
32
users.go
32
users.go
@ -47,10 +47,10 @@ func (u *User) HashPassword() {
|
||||
u.Password = ""
|
||||
}
|
||||
|
||||
func GetUser(userid int64) (*User, error) {
|
||||
func GetUser(db *DB, userid int64) (*User, error) {
|
||||
var u User
|
||||
|
||||
err := DB.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||
err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -67,18 +67,18 @@ func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func GetUserByUsername(username string) (*User, error) {
|
||||
func GetUserByUsername(db *DB, username string) (*User, error) {
|
||||
var u User
|
||||
|
||||
err := DB.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||
err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func InsertUser(u *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
func InsertUser(db *DB, u *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -136,16 +136,16 @@ func InsertUser(u *User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserFromSession(r *http.Request) (*User, error) {
|
||||
s, err := GetSession(r)
|
||||
func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
|
||||
s, err := GetSession(db, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return GetUser(s.UserId)
|
||||
return GetUser(db, s.UserId)
|
||||
}
|
||||
|
||||
func UpdateUser(u *User) error {
|
||||
transaction, err := DB.Begin()
|
||||
func UpdateUser(db *DB, u *User) error {
|
||||
transaction, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -180,7 +180,7 @@ func UpdateUser(u *User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
||||
if r.Method == "POST" {
|
||||
user_json := r.PostFormValue("user")
|
||||
if user_json == "" {
|
||||
@ -197,7 +197,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user.UserId = -1
|
||||
user.HashPassword()
|
||||
|
||||
err = InsertUser(&user)
|
||||
err = InsertUser(db, &user)
|
||||
if err != nil {
|
||||
if _, ok := err.(UserExistsError); ok {
|
||||
WriteError(w, 4 /*User Exists*/)
|
||||
@ -216,7 +216,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
user, err := GetUserFromSession(r)
|
||||
user, err := GetUserFromSession(db, r)
|
||||
if err != nil {
|
||||
WriteError(w, 1 /*Not Signed In*/)
|
||||
return
|
||||
@ -264,7 +264,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user.PasswordHash = old_pwhash
|
||||
}
|
||||
|
||||
err = UpdateUser(user)
|
||||
err = UpdateUser(db, user)
|
||||
if err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
@ -278,7 +278,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else if r.Method == "DELETE" {
|
||||
count, err := DB.Delete(&user)
|
||||
count, err := db.Delete(&user)
|
||||
if count != 1 || err != nil {
|
||||
WriteError(w, 999 /*Internal Error*/)
|
||||
log.Print(err)
|
||||
|
Loading…
Reference in New Issue
Block a user