mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-10-30 09:33:25 -04:00 
			
		
		
		
	First pass at reorganizing go code into sub-packages
This commit is contained in:
		
							
								
								
									
										9
									
								
								internal/handlers/Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								internal/handlers/Makefile
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| all: security_templates.go | ||||
|  | ||||
| security_templates.go: cusip_list.csv scripts/gen_security_list.py | ||||
| 	./scripts/gen_security_list.py > security_templates.go | ||||
|  | ||||
| cusip_list.csv: | ||||
| 	./scripts/gen_cusip_csv.sh > cusip_list.csv | ||||
|  | ||||
| .PHONY = all | ||||
							
								
								
									
										561
									
								
								internal/handlers/accounts.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										561
									
								
								internal/handlers/accounts.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,561 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type AccountType int64 | ||||
|  | ||||
| const ( | ||||
| 	Bank       AccountType = 1 // start at 1 so that the default (0) is invalid | ||||
| 	Cash                   = 2 | ||||
| 	Asset                  = 3 | ||||
| 	Liability              = 4 | ||||
| 	Investment             = 5 | ||||
| 	Income                 = 6 | ||||
| 	Expense                = 7 | ||||
| 	Trading                = 8 | ||||
| 	Equity                 = 9 | ||||
| 	Receivable             = 10 | ||||
| 	Payable                = 11 | ||||
| ) | ||||
|  | ||||
| var AccountTypes = []AccountType{ | ||||
| 	Bank, | ||||
| 	Cash, | ||||
| 	Asset, | ||||
| 	Liability, | ||||
| 	Investment, | ||||
| 	Income, | ||||
| 	Expense, | ||||
| 	Trading, | ||||
| 	Equity, | ||||
| 	Receivable, | ||||
| 	Payable, | ||||
| } | ||||
|  | ||||
| func (t AccountType) String() string { | ||||
| 	switch t { | ||||
| 	case Bank: | ||||
| 		return "Bank" | ||||
| 	case Cash: | ||||
| 		return "Cash" | ||||
| 	case Asset: | ||||
| 		return "Asset" | ||||
| 	case Liability: | ||||
| 		return "Liability" | ||||
| 	case Investment: | ||||
| 		return "Investment" | ||||
| 	case Income: | ||||
| 		return "Income" | ||||
| 	case Expense: | ||||
| 		return "Expense" | ||||
| 	case Trading: | ||||
| 		return "Trading" | ||||
| 	case Equity: | ||||
| 		return "Equity" | ||||
| 	case Receivable: | ||||
| 		return "Receivable" | ||||
| 	case Payable: | ||||
| 		return "Payable" | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| type Account struct { | ||||
| 	AccountId         int64 | ||||
| 	ExternalAccountId string | ||||
| 	UserId            int64 | ||||
| 	SecurityId        int64 | ||||
| 	ParentAccountId   int64 // -1 if this account is at the root | ||||
| 	Type              AccountType | ||||
| 	Name              string | ||||
|  | ||||
| 	// monotonically-increasing account transaction version number. Used for | ||||
| 	// allowing a client to ensure they have a consistent version when paging | ||||
| 	// through transactions. | ||||
| 	AccountVersion int64 `json:"Version"` | ||||
|  | ||||
| 	// Optional fields specifying how to fetch transactions from a bank via OFX | ||||
| 	OFXURL       string | ||||
| 	OFXORG       string | ||||
| 	OFXFID       string | ||||
| 	OFXUser      string | ||||
| 	OFXBankID    string // OFX BankID (BrokerID if AcctType == Investment) | ||||
| 	OFXAcctID    string | ||||
| 	OFXAcctType  string // ofxgo.acctType | ||||
| 	OFXClientUID string | ||||
| 	OFXAppID     string | ||||
| 	OFXAppVer    string | ||||
| 	OFXVersion   string | ||||
| 	OFXNoIndent  bool | ||||
| } | ||||
|  | ||||
| type AccountList struct { | ||||
| 	Accounts *[]Account `json:"accounts"` | ||||
| } | ||||
|  | ||||
| var accountTransactionsRE *regexp.Regexp | ||||
| var accountImportRE *regexp.Regexp | ||||
|  | ||||
| func init() { | ||||
| 	accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`) | ||||
| 	accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/[a-z]+/?$`) | ||||
| } | ||||
|  | ||||
| func (a *Account) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(a) | ||||
| } | ||||
|  | ||||
| func (a *Account) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(a) | ||||
| } | ||||
|  | ||||
| func (al *AccountList) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(al) | ||||
| } | ||||
|  | ||||
| func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) { | ||||
| 	var a Account | ||||
|  | ||||
| 	err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &a, nil | ||||
| } | ||||
|  | ||||
| func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) { | ||||
| 	var a Account | ||||
|  | ||||
| 	err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &a, nil | ||||
| } | ||||
|  | ||||
| func GetAccounts(db *DB, userid int64) (*[]Account, error) { | ||||
| 	var accounts []Account | ||||
|  | ||||
| 	_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &accounts, nil | ||||
| } | ||||
|  | ||||
| // Get (and attempt to create if it doesn't exist). Matches on UserId, | ||||
| // SecurityId, Type, Name, and ParentAccountId | ||||
| func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) { | ||||
| 	var accounts []Account | ||||
| 	var account Account | ||||
|  | ||||
| 	// Try to find the top-level trading account | ||||
| 	_, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(accounts) == 1 { | ||||
| 		account = accounts[0] | ||||
| 	} else { | ||||
| 		account.UserId = a.UserId | ||||
| 		account.SecurityId = a.SecurityId | ||||
| 		account.Type = a.Type | ||||
| 		account.Name = a.Name | ||||
| 		account.ParentAccountId = a.ParentAccountId | ||||
|  | ||||
| 		err = transaction.Insert(&account) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	return &account, nil | ||||
| } | ||||
|  | ||||
| // Get (and attempt to create if it doesn't exist) the security/currency | ||||
| // trading account for the supplied security/currency | ||||
| func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { | ||||
| 	var tradingAccount Account | ||||
| 	var account Account | ||||
|  | ||||
| 	user, err := GetUserTx(transaction, userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	tradingAccount.UserId = userid | ||||
| 	tradingAccount.Type = Trading | ||||
| 	tradingAccount.Name = "Trading" | ||||
| 	tradingAccount.SecurityId = user.DefaultCurrency | ||||
| 	tradingAccount.ParentAccountId = -1 | ||||
|  | ||||
| 	// Find/create the top-level trading account | ||||
| 	ta, err := GetCreateAccountTx(transaction, tradingAccount) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	security, err := GetSecurityTx(transaction, securityid, userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	account.UserId = userid | ||||
| 	account.Name = security.Name | ||||
| 	account.ParentAccountId = ta.AccountId | ||||
| 	account.SecurityId = securityid | ||||
| 	account.Type = Trading | ||||
|  | ||||
| 	a, err := GetCreateAccountTx(transaction, account) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return a, nil | ||||
| } | ||||
|  | ||||
| // Get (and attempt to create if it doesn't exist) the security/currency | ||||
| // imbalance account for the supplied security/currency | ||||
| func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { | ||||
| 	var imbalanceAccount Account | ||||
| 	var account Account | ||||
| 	xxxtemplate := FindSecurityTemplate("XXX", Currency) | ||||
| 	if xxxtemplate == nil { | ||||
| 		return nil, errors.New("Couldn't find XXX security template") | ||||
| 	} | ||||
| 	xxxsecurity, err := ImportGetCreateSecurity(transaction, userid, xxxtemplate) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.New("Couldn't create XXX security") | ||||
| 	} | ||||
|  | ||||
| 	imbalanceAccount.UserId = userid | ||||
| 	imbalanceAccount.Name = "Imbalances" | ||||
| 	imbalanceAccount.ParentAccountId = -1 | ||||
| 	imbalanceAccount.SecurityId = xxxsecurity.SecurityId | ||||
| 	imbalanceAccount.Type = Bank | ||||
|  | ||||
| 	// Find/create the top-level trading account | ||||
| 	ia, err := GetCreateAccountTx(transaction, imbalanceAccount) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	security, err := GetSecurityTx(transaction, securityid, userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	account.UserId = userid | ||||
| 	account.Name = security.Name | ||||
| 	account.ParentAccountId = ia.AccountId | ||||
| 	account.SecurityId = securityid | ||||
| 	account.Type = Bank | ||||
|  | ||||
| 	a, err := GetCreateAccountTx(transaction, account) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return a, nil | ||||
| } | ||||
|  | ||||
| type ParentAccountMissingError struct{} | ||||
|  | ||||
| func (pame ParentAccountMissingError) Error() string { | ||||
| 	return "Parent account missing" | ||||
| } | ||||
|  | ||||
| func insertUpdateAccount(db *DB, a *Account, insert bool) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if a.ParentAccountId != -1 { | ||||
| 		existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", a.ParentAccountId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 		if existing != 1 { | ||||
| 			transaction.Rollback() | ||||
| 			return ParentAccountMissingError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if insert { | ||||
| 		err = transaction.Insert(a) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		a.AccountVersion = oldacct.AccountVersion + 1 | ||||
|  | ||||
| 		count, err := transaction.Update(a) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			transaction.Rollback() | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InsertAccount(db *DB, a *Account) error { | ||||
| 	return insertUpdateAccount(db, a, true) | ||||
| } | ||||
|  | ||||
| func UpdateAccount(db *DB, a *Account) error { | ||||
| 	return insertUpdateAccount(db, a, false) | ||||
| } | ||||
|  | ||||
| func DeleteAccount(db *DB, a *Account) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if a.ParentAccountId != -1 { | ||||
| 		// Re-parent splits to this account's parent account if this account isn't a root account | ||||
| 		_, err = transaction.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Delete splits if this account is a root account | ||||
| 		_, err = transaction.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Re-parent child accounts to this account's parent account | ||||
| 	_, err = transaction.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(a) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Was going to delete more than one account") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		// if URL looks like /account/[0-9]+/import, use the account | ||||
| 		// import handler | ||||
| 		if accountImportRE.MatchString(r.URL.Path) { | ||||
| 			var accountid int64 | ||||
| 			var importtype string | ||||
| 			n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype) | ||||
|  | ||||
| 			if err != nil || n != 2 { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			AccountImportHandler(db, w, r, user, accountid, importtype) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		account_json := r.PostFormValue("account") | ||||
| 		if account_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var account Account | ||||
| 		err := account.Read(account_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		account.AccountId = -1 | ||||
| 		account.UserId = user.UserId | ||||
| 		account.AccountVersion = 0 | ||||
|  | ||||
| 		security, err := GetSecurity(db, account.SecurityId, user.UserId) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		if security == nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = InsertAccount(db, &account) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(ParentAccountMissingError); ok { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 			} else { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = account.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		var accountid int64 | ||||
| 		n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) | ||||
|  | ||||
| 		if err != nil || n != 1 { | ||||
| 			//Return all Accounts | ||||
| 			var al AccountList | ||||
| 			accounts, err := GetAccounts(db, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			al.Accounts = accounts | ||||
| 			err = (&al).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			// if URL looks like /account/[0-9]+/transactions, use the account | ||||
| 			// transaction handler | ||||
| 			if accountTransactionsRE.MatchString(r.URL.Path) { | ||||
| 				AccountTransactionsHandler(db, w, r, user, accountid) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// Return Account with this Id | ||||
| 			account, err := GetAccount(db, accountid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = account.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		accountid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		if r.Method == "PUT" { | ||||
| 			account_json := r.PostFormValue("account") | ||||
| 			if account_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var account Account | ||||
| 			err := account.Read(account_json) | ||||
| 			if err != nil || account.AccountId != accountid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			account.UserId = user.UserId | ||||
|  | ||||
| 			security, err := GetSecurity(db, account.SecurityId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			if security == nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateAccount(db, &account) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = account.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			account, err := GetAccount(db, accountid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteAccount(db, account) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										217
									
								
								internal/handlers/accounts_lua.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								internal/handlers/accounts_lua.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,217 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"math/big" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const luaAccountTypeName = "account" | ||||
|  | ||||
| func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { | ||||
| 	var account_map map[int64]*Account | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find DB in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) | ||||
| 	if !ok { | ||||
| 		user, ok := ctx.Value(userContextKey).(*User) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		accounts, err := GetAccounts(db, user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		account_map = make(map[int64]*Account) | ||||
| 		for i := range *accounts { | ||||
| 			account_map[(*accounts)[i].AccountId] = &(*accounts)[i] | ||||
| 		} | ||||
|  | ||||
| 		ctx = context.WithValue(ctx, accountsContextKey, account_map) | ||||
| 		L.SetContext(ctx) | ||||
| 	} | ||||
|  | ||||
| 	return account_map, nil | ||||
| } | ||||
|  | ||||
| func luaGetAccounts(L *lua.LState) int { | ||||
| 	account_map, err := luaContextGetAccounts(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetAccounts couldn't fetch accounts") | ||||
| 	} | ||||
|  | ||||
| 	table := L.NewTable() | ||||
|  | ||||
| 	for accountid := range account_map { | ||||
| 		table.RawSetInt(int(accountid), AccountToLua(L, account_map[accountid])) | ||||
| 	} | ||||
|  | ||||
| 	L.Push(table) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaRegisterAccounts(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaAccountTypeName) | ||||
| 	L.SetGlobal("account", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaAccount__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
|  | ||||
| 	for _, accttype := range AccountTypes { | ||||
| 		L.SetField(mt, accttype.String(), lua.LNumber(float64(accttype))) | ||||
| 	} | ||||
|  | ||||
| 	getAccountsFn := L.NewFunction(luaGetAccounts) | ||||
| 	L.SetField(mt, "get_all", getAccountsFn) | ||||
| 	// also register the get_accounts function as a global in its own right | ||||
| 	L.SetGlobal("get_accounts", getAccountsFn) | ||||
| } | ||||
|  | ||||
| func AccountToLua(L *lua.LState, account *Account) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = account | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Account and returns this *Account. | ||||
| func luaCheckAccount(L *lua.LState, n int) *Account { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if account, ok := ud.Value.(*Account); ok { | ||||
| 		return account | ||||
| 	} | ||||
| 	L.ArgError(n, "account expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaAccount__index(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "AccountId", "accountid": | ||||
| 		L.Push(lua.LNumber(float64(a.AccountId))) | ||||
| 	case "Security", "security": | ||||
| 		security_map, err := luaContextGetSecurities(L) | ||||
| 		if err != nil { | ||||
| 			panic("account.security couldn't fetch securities") | ||||
| 		} | ||||
| 		if security, ok := security_map[a.SecurityId]; ok { | ||||
| 			L.Push(SecurityToLua(L, security)) | ||||
| 		} else { | ||||
| 			panic("SecurityId not in lua security_map") | ||||
| 		} | ||||
| 	case "SecurityId", "securityid": | ||||
| 		L.Push(lua.LNumber(float64(a.SecurityId))) | ||||
| 	case "Parent", "parent", "ParentAccount", "parentaccount": | ||||
| 		if a.ParentAccountId == -1 { | ||||
| 			L.Push(lua.LNil) | ||||
| 		} else { | ||||
| 			account_map, err := luaContextGetAccounts(L) | ||||
| 			if err != nil { | ||||
| 				panic("account.parent couldn't fetch accounts") | ||||
| 			} | ||||
| 			if parent, ok := account_map[a.ParentAccountId]; ok { | ||||
| 				L.Push(AccountToLua(L, parent)) | ||||
| 			} else { | ||||
| 				panic("ParentAccountId not in lua account_map") | ||||
| 			} | ||||
| 		} | ||||
| 	case "Name", "name": | ||||
| 		L.Push(lua.LString(a.Name)) | ||||
| 	case "Type", "type": | ||||
| 		L.Push(lua.LNumber(float64(a.Type))) | ||||
| 	case "TypeName", "Typename": | ||||
| 		L.Push(lua.LString(a.Type.String())) | ||||
| 	case "typename": | ||||
| 		L.Push(lua.LString(strings.ToLower(a.Type.String()))) | ||||
| 	case "Balance", "balance": | ||||
| 		L.Push(L.NewFunction(luaAccountBalance)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected account attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaAccountBalance(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find DB in lua's Context") | ||||
| 	} | ||||
| 	user, ok := ctx.Value(userContextKey).(*User) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find User in lua's Context") | ||||
| 	} | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		panic("account.security couldn't fetch securities") | ||||
| 	} | ||||
| 	security, ok := security_map[a.SecurityId] | ||||
| 	if !ok { | ||||
| 		panic("SecurityId not in lua security_map") | ||||
| 	} | ||||
| 	date := luaWeakCheckTime(L, 2) | ||||
| 	var b Balance | ||||
| 	var rat *big.Rat | ||||
| 	if date != nil { | ||||
| 		end := luaWeakCheckTime(L, 3) | ||||
| 		if end != nil { | ||||
| 			rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end) | ||||
| 		} else { | ||||
| 			rat, err = GetAccountBalanceDate(db, user, a.AccountId, date) | ||||
| 		} | ||||
| 	} else { | ||||
| 		rat, err = GetAccountBalance(db, user, a.AccountId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		panic("Failed to GetAccountBalance:" + err.Error()) | ||||
| 	} | ||||
| 	b.Amount = rat | ||||
| 	b.Security = security | ||||
| 	L.Push(BalanceToLua(L, &b)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaAccount__tostring(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
|  | ||||
| 	account_map, err := luaContextGetAccounts(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetAccounts couldn't fetch accounts") | ||||
| 	} | ||||
|  | ||||
| 	full_name := a.Name | ||||
| 	for a.ParentAccountId != -1 { | ||||
| 		a = account_map[a.ParentAccountId] | ||||
| 		full_name = a.Name + "/" + full_name | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LString(full_name)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaAccount__eq(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
| 	b := luaCheckAccount(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.AccountId == b.AccountId)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										224
									
								
								internal/handlers/balance_lua.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										224
									
								
								internal/handlers/balance_lua.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,224 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"math/big" | ||||
| ) | ||||
|  | ||||
| type Balance struct { | ||||
| 	Security *Security | ||||
| 	Amount   *big.Rat | ||||
| } | ||||
|  | ||||
| const luaBalanceTypeName = "balance" | ||||
|  | ||||
| func luaRegisterBalances(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaBalanceTypeName) | ||||
| 	L.SetGlobal("balance", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaBalance__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaBalance__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaBalance__eq)) | ||||
| 	L.SetField(mt, "__lt", L.NewFunction(luaBalance__lt)) | ||||
| 	L.SetField(mt, "__le", L.NewFunction(luaBalance__le)) | ||||
| 	L.SetField(mt, "__add", L.NewFunction(luaBalance__add)) | ||||
| 	L.SetField(mt, "__sub", L.NewFunction(luaBalance__sub)) | ||||
| 	L.SetField(mt, "__mul", L.NewFunction(luaBalance__mul)) | ||||
| 	L.SetField(mt, "__div", L.NewFunction(luaBalance__div)) | ||||
| 	L.SetField(mt, "__unm", L.NewFunction(luaBalance__unm)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| func BalanceToLua(L *lua.LState, balance *Balance) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = balance | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaBalanceTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Balance and returns this *Balance. | ||||
| func luaCheckBalance(L *lua.LState, n int) *Balance { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if balance, ok := ud.Value.(*Balance); ok { | ||||
| 		return balance | ||||
| 	} | ||||
| 	L.ArgError(n, "balance expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaWeakCheckBalance(L *lua.LState, n int) *Balance { | ||||
| 	v := L.Get(n) | ||||
| 	if ud, ok := v.(*lua.LUserData); ok { | ||||
| 		if balance, ok := ud.Value.(*Balance); ok { | ||||
| 			return balance | ||||
| 		} | ||||
| 		L.ArgError(n, "balance expected") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaGetBalanceOperands(L *lua.LState, n int, m int) (*Balance, *Balance) { | ||||
| 	bn := luaWeakCheckBalance(L, n) | ||||
| 	bm := luaWeakCheckBalance(L, m) | ||||
|  | ||||
| 	if bn != nil && bm != nil { | ||||
| 		return bn, bm | ||||
| 	} else if bn != nil { | ||||
| 		nm := L.CheckNumber(m) | ||||
| 		var balance Balance | ||||
| 		var rat big.Rat | ||||
| 		balance.Security = bn.Security | ||||
| 		balance.Amount = rat.SetFloat64(float64(nm)) | ||||
| 		if balance.Amount == nil { | ||||
| 			L.ArgError(n, "non-finite float invalid for operand to balance arithemetic") | ||||
| 			return nil, nil | ||||
| 		} | ||||
| 		return bn, &balance | ||||
| 	} else if bm != nil { | ||||
| 		nn := L.CheckNumber(n) | ||||
| 		var balance Balance | ||||
| 		var rat big.Rat | ||||
| 		balance.Security = bm.Security | ||||
| 		balance.Amount = rat.SetFloat64(float64(nn)) | ||||
| 		if balance.Amount == nil { | ||||
| 			L.ArgError(n, "non-finite float invalid for operand to balance arithemetic") | ||||
| 			return nil, nil | ||||
| 		} | ||||
| 		return bm, &balance | ||||
| 	} | ||||
| 	L.ArgError(n, "balance expected") | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func luaBalance__index(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Security", "security": | ||||
| 		L.Push(SecurityToLua(L, a.Security)) | ||||
| 	case "Amount", "amount": | ||||
| 		float, _ := a.Amount.Float64() | ||||
| 		L.Push(lua.LNumber(float)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected balance attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__tostring(L *lua.LState) int { | ||||
| 	b := luaCheckBalance(L, 1) | ||||
|  | ||||
| 	L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__eq(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	b := luaCheckBalance(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__lt(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	b := luaCheckBalance(L, 2) | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't compare balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Amount.Cmp(b.Amount) < 0)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__le(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	b := luaCheckBalance(L, 2) | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't compare balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Amount.Cmp(b.Amount) <= 0)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__add(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't add balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Add(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__sub(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't subtract balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Sub(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__mul(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't multiply balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Mul(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__div(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't divide balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Quo(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__unm(L *lua.LState) int { | ||||
| 	b := luaCheckBalance(L, 1) | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = b.Security | ||||
| 	balance.Amount = rat.Neg(b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										169
									
								
								internal/handlers/date_lua.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								internal/handlers/date_lua.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,169 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const luaDateTypeName = "date" | ||||
| const timeFormat = "2006-01-02" | ||||
|  | ||||
| func luaRegisterDates(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaDateTypeName) | ||||
| 	L.SetGlobal("date", mt) | ||||
| 	L.SetField(mt, "new", L.NewFunction(luaDateNew)) | ||||
| 	L.SetField(mt, "now", L.NewFunction(luaDateNow)) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaDate__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaDate__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaDate__eq)) | ||||
| 	L.SetField(mt, "__lt", L.NewFunction(luaDate__lt)) | ||||
| 	L.SetField(mt, "__le", L.NewFunction(luaDate__le)) | ||||
| 	L.SetField(mt, "__add", L.NewFunction(luaDate__add)) | ||||
| 	L.SetField(mt, "__sub", L.NewFunction(luaDate__sub)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| func TimeToLua(L *lua.LState, date *time.Time) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = date | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaDateTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Time and returns this *Time. | ||||
| func luaCheckTime(L *lua.LState, n int) *time.Time { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if date, ok := ud.Value.(*time.Time); ok { | ||||
| 		return date | ||||
| 	} | ||||
| 	L.ArgError(n, "date expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaWeakCheckTime(L *lua.LState, n int) *time.Time { | ||||
| 	v := L.Get(n) | ||||
| 	if ud, ok := v.(*lua.LUserData); ok { | ||||
| 		if date, ok := ud.Value.(*time.Time); ok { | ||||
| 			return date | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaWeakCheckTableFieldInt(L *lua.LState, T *lua.LTable, n int, name string, def int) int { | ||||
| 	lv := T.RawGetString(name) | ||||
| 	if lv == lua.LNil { | ||||
| 		return def | ||||
| 	} | ||||
| 	if i, ok := lv.(lua.LNumber); ok { | ||||
| 		return int(i) | ||||
| 	} | ||||
| 	L.ArgError(n, "table field '"+name+"' expected to be int") | ||||
| 	return def | ||||
| } | ||||
|  | ||||
| func luaDateNew(L *lua.LState) int { | ||||
| 	v := L.Get(1) | ||||
| 	if s, ok := v.(lua.LString); ok { | ||||
| 		date, err := time.Parse(timeFormat, s.String()) | ||||
| 		if err != nil { | ||||
| 			L.ArgError(1, "error parsing date string: "+err.Error()) | ||||
| 			return 0 | ||||
| 		} | ||||
| 		L.Push(TimeToLua(L, &date)) | ||||
| 		return 1 | ||||
| 	} | ||||
| 	var year, month, day int | ||||
| 	if t, ok := v.(*lua.LTable); ok { | ||||
| 		year = luaWeakCheckTableFieldInt(L, t, 1, "year", 0) | ||||
| 		month = luaWeakCheckTableFieldInt(L, t, 1, "month", 1) | ||||
| 		day = luaWeakCheckTableFieldInt(L, t, 1, "day", 1) | ||||
| 	} else { | ||||
| 		year = L.CheckInt(1) | ||||
| 		month = L.CheckInt(2) | ||||
| 		day = L.CheckInt(3) | ||||
| 	} | ||||
| 	date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDateNow(L *lua.LState) int { | ||||
| 	now := time.Now() | ||||
| 	date := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__index(L *lua.LState) int { | ||||
| 	d := luaCheckTime(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Year", "year": | ||||
| 		L.Push(lua.LNumber(d.Year())) | ||||
| 	case "Month", "month": | ||||
| 		L.Push(lua.LNumber(float64(d.Month()))) | ||||
| 	case "Day", "day": | ||||
| 		L.Push(lua.LNumber(float64(d.Day()))) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected date attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__tostring(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
|  | ||||
| 	L.Push(lua.LString(a.Format(timeFormat))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__eq(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Equal(*b))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__lt(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Before(*b))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__le(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Equal(*b) || a.Before(*b))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__add(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	date := a.AddDate(b.Year(), int(b.Month()), b.Day()) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__sub(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	date := a.AddDate(-b.Year(), -int(b.Month()), -b.Day()) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										37
									
								
								internal/handlers/errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								internal/handlers/errors.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type Error struct { | ||||
| 	ErrorId     int | ||||
| 	ErrorString string | ||||
| } | ||||
|  | ||||
| var error_codes = map[int]string{ | ||||
| 	1: "Not Signed In", | ||||
| 	2: "Unauthorized Access", | ||||
| 	3: "Invalid Request", | ||||
| 	4: "User Exists", | ||||
| 	//  5:   "Connection Failed", //reserved for client-side error | ||||
| 	6:   "Import Error", | ||||
| 	999: "Internal Error", | ||||
| } | ||||
|  | ||||
| func WriteError(w http.ResponseWriter, error_code int) { | ||||
| 	msg, ok := error_codes[error_code] | ||||
| 	if !ok { | ||||
| 		log.Printf("Error: WriteError received error code of %d", error_code) | ||||
| 		msg = error_codes[999] | ||||
| 	} | ||||
| 	e := Error{error_code, msg} | ||||
|  | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	err := enc.Encode(e) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										493
									
								
								internal/handlers/gnucash.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										493
									
								
								internal/handlers/gnucash.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,493 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"compress/gzip" | ||||
| 	"encoding/xml" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"math" | ||||
| 	"math/big" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type GnucashXMLCommodity struct { | ||||
| 	Name        string `xml:"http://www.gnucash.org/XML/cmdty id"` | ||||
| 	Description string `xml:"http://www.gnucash.org/XML/cmdty name"` | ||||
| 	Type        string `xml:"http://www.gnucash.org/XML/cmdty space"` | ||||
| 	Fraction    int    `xml:"http://www.gnucash.org/XML/cmdty fraction"` | ||||
| 	XCode       string `xml:"http://www.gnucash.org/XML/cmdty xcode"` | ||||
| } | ||||
|  | ||||
| type GnucashCommodity struct{ Security } | ||||
|  | ||||
| func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { | ||||
| 	var gxc GnucashXMLCommodity | ||||
| 	if err := d.DecodeElement(&gxc, &start); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	gc.Name = gxc.Name | ||||
| 	gc.Symbol = gxc.Name | ||||
| 	gc.Description = gxc.Description | ||||
| 	gc.AlternateId = gxc.XCode | ||||
|  | ||||
| 	gc.Security.Type = Stock // assumed default | ||||
| 	if gxc.Type == "ISO4217" { | ||||
| 		gc.Security.Type = Currency | ||||
| 		// Get the number from our templates for the AlternateId because | ||||
| 		// Gnucash uses 'id' (our Name) to supply the string ISO4217 code | ||||
| 		template := FindSecurityTemplate(gxc.Name, Currency) | ||||
| 		if template == nil { | ||||
| 			return errors.New("Unable to find security template for Gnucash ISO4217 commodity") | ||||
| 		} | ||||
| 		gc.AlternateId = template.AlternateId | ||||
| 		gc.Precision = template.Precision | ||||
| 	} else { | ||||
| 		if gxc.Fraction > 0 { | ||||
| 			gc.Precision = int(math.Ceil(math.Log10(float64(gxc.Fraction)))) | ||||
| 		} else { | ||||
| 			gc.Precision = 0 | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type GnucashTime struct{ time.Time } | ||||
|  | ||||
| func (g *GnucashTime) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { | ||||
| 	var s string | ||||
| 	if err := d.DecodeElement(&s, &start); err != nil { | ||||
| 		return fmt.Errorf("date should be a string") | ||||
| 	} | ||||
| 	t, err := time.Parse("2006-01-02 15:04:05 -0700", s) | ||||
| 	g.Time = t | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| type GnucashDate struct { | ||||
| 	Date GnucashTime `xml:"http://www.gnucash.org/XML/ts date"` | ||||
| } | ||||
|  | ||||
| type GnucashPrice struct { | ||||
| 	Id        string           `xml:"http://www.gnucash.org/XML/price id"` | ||||
| 	Commodity GnucashCommodity `xml:"http://www.gnucash.org/XML/price commodity"` | ||||
| 	Currency  GnucashCommodity `xml:"http://www.gnucash.org/XML/price currency"` | ||||
| 	Date      GnucashDate      `xml:"http://www.gnucash.org/XML/price time"` | ||||
| 	Source    string           `xml:"http://www.gnucash.org/XML/price source"` | ||||
| 	Type      string           `xml:"http://www.gnucash.org/XML/price type"` | ||||
| 	Value     string           `xml:"http://www.gnucash.org/XML/price value"` | ||||
| } | ||||
|  | ||||
| type GnucashPriceDB struct { | ||||
| 	Prices []GnucashPrice `xml:"price"` | ||||
| } | ||||
|  | ||||
| type GnucashAccount struct { | ||||
| 	Version         string              `xml:"version,attr"` | ||||
| 	accountid       int64               // Used to map Gnucash guid's to integer ones | ||||
| 	AccountId       string              `xml:"http://www.gnucash.org/XML/act id"` | ||||
| 	ParentAccountId string              `xml:"http://www.gnucash.org/XML/act parent"` | ||||
| 	Name            string              `xml:"http://www.gnucash.org/XML/act name"` | ||||
| 	Description     string              `xml:"http://www.gnucash.org/XML/act description"` | ||||
| 	Type            string              `xml:"http://www.gnucash.org/XML/act type"` | ||||
| 	Commodity       GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/act commodity"` | ||||
| } | ||||
|  | ||||
| type GnucashTransaction struct { | ||||
| 	TransactionId string              `xml:"http://www.gnucash.org/XML/trn id"` | ||||
| 	Description   string              `xml:"http://www.gnucash.org/XML/trn description"` | ||||
| 	Number        string              `xml:"http://www.gnucash.org/XML/trn num"` | ||||
| 	DatePosted    GnucashDate         `xml:"http://www.gnucash.org/XML/trn date-posted"` | ||||
| 	DateEntered   GnucashDate         `xml:"http://www.gnucash.org/XML/trn date-entered"` | ||||
| 	Commodity     GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/trn currency"` | ||||
| 	Splits        []GnucashSplit      `xml:"http://www.gnucash.org/XML/trn splits>split"` | ||||
| } | ||||
|  | ||||
| type GnucashSplit struct { | ||||
| 	SplitId   string `xml:"http://www.gnucash.org/XML/split id"` | ||||
| 	Status    string `xml:"http://www.gnucash.org/XML/split reconciled-state"` | ||||
| 	AccountId string `xml:"http://www.gnucash.org/XML/split account"` | ||||
| 	Memo      string `xml:"http://www.gnucash.org/XML/split memo"` | ||||
| 	Amount    string `xml:"http://www.gnucash.org/XML/split quantity"` | ||||
| 	Value     string `xml:"http://www.gnucash.org/XML/split value"` | ||||
| } | ||||
|  | ||||
| type GnucashXMLImport struct { | ||||
| 	XMLName      xml.Name             `xml:"gnc-v2"` | ||||
| 	Commodities  []GnucashCommodity   `xml:"http://www.gnucash.org/XML/gnc book>commodity"` | ||||
| 	PriceDB      GnucashPriceDB       `xml:"http://www.gnucash.org/XML/gnc book>pricedb"` | ||||
| 	Accounts     []GnucashAccount     `xml:"http://www.gnucash.org/XML/gnc book>account"` | ||||
| 	Transactions []GnucashTransaction `xml:"http://www.gnucash.org/XML/gnc book>transaction"` | ||||
| } | ||||
|  | ||||
| type GnucashImport struct { | ||||
| 	Securities   []Security | ||||
| 	Accounts     []Account | ||||
| 	Transactions []Transaction | ||||
| 	Prices       []Price | ||||
| } | ||||
|  | ||||
| func ImportGnucash(r io.Reader) (*GnucashImport, error) { | ||||
| 	var gncxml GnucashXMLImport | ||||
| 	var gncimport GnucashImport | ||||
|  | ||||
| 	// Perform initial parsing of xml into structs | ||||
| 	decoder := xml.NewDecoder(r) | ||||
| 	err := decoder.Decode(&gncxml) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Fixup securities, making a map of them as we go | ||||
| 	securityMap := make(map[string]Security) | ||||
| 	for i := range gncxml.Commodities { | ||||
| 		s := gncxml.Commodities[i].Security | ||||
| 		s.SecurityId = int64(i + 1) | ||||
| 		securityMap[s.Name] = s | ||||
|  | ||||
| 		// Ignore gnucash's "template" commodity | ||||
| 		if s.Name != "template" || | ||||
| 			s.Description != "template" || | ||||
| 			s.AlternateId != "template" { | ||||
| 			gncimport.Securities = append(gncimport.Securities, s) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Create prices, setting security and currency IDs from securityMap | ||||
| 	for i := range gncxml.PriceDB.Prices { | ||||
| 		price := gncxml.PriceDB.Prices[i] | ||||
| 		var p Price | ||||
| 		security, ok := securityMap[price.Commodity.Name] | ||||
| 		if !ok { | ||||
| 			return nil, fmt.Errorf("Unable to find commodity '%s' for price '%s'", price.Commodity.Name, price.Id) | ||||
| 		} | ||||
| 		currency, ok := securityMap[price.Currency.Name] | ||||
| 		if !ok { | ||||
| 			return nil, fmt.Errorf("Unable to find currency '%s' for price '%s'", price.Currency.Name, price.Id) | ||||
| 		} | ||||
| 		if currency.Type != Currency { | ||||
| 			return nil, fmt.Errorf("Currency for imported price isn't actually a currency\n") | ||||
| 		} | ||||
| 		p.PriceId = int64(i + 1) | ||||
| 		p.SecurityId = security.SecurityId | ||||
| 		p.CurrencyId = currency.SecurityId | ||||
| 		p.Date = price.Date.Date.Time | ||||
|  | ||||
| 		var r big.Rat | ||||
| 		_, ok = r.SetString(price.Value) | ||||
| 		if ok { | ||||
| 			p.Value = r.FloatString(currency.Precision) | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("Can't set price value: %s", price.Value) | ||||
| 		} | ||||
|  | ||||
| 		p.RemoteId = "gnucash:" + price.Id | ||||
| 		gncimport.Prices = append(gncimport.Prices, p) | ||||
| 	} | ||||
|  | ||||
| 	//find root account, while simultaneously creating map of GUID's to | ||||
| 	//accounts | ||||
| 	var rootAccount GnucashAccount | ||||
| 	accountMap := make(map[string]GnucashAccount) | ||||
| 	for i := range gncxml.Accounts { | ||||
| 		gncxml.Accounts[i].accountid = int64(i + 1) | ||||
| 		if gncxml.Accounts[i].Type == "ROOT" { | ||||
| 			rootAccount = gncxml.Accounts[i] | ||||
| 		} else { | ||||
| 			accountMap[gncxml.Accounts[i].AccountId] = gncxml.Accounts[i] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	//Translate to our account format, figuring out parent relationships | ||||
| 	for guid := range accountMap { | ||||
| 		ga := accountMap[guid] | ||||
| 		var a Account | ||||
|  | ||||
| 		a.AccountId = ga.accountid | ||||
| 		if ga.ParentAccountId == rootAccount.AccountId { | ||||
| 			a.ParentAccountId = -1 | ||||
| 		} else { | ||||
| 			parent, ok := accountMap[ga.ParentAccountId] | ||||
| 			if ok { | ||||
| 				a.ParentAccountId = parent.accountid | ||||
| 			} else { | ||||
| 				a.ParentAccountId = -1 // Ugly, but assign to top-level if we can't find its parent | ||||
| 			} | ||||
| 		} | ||||
| 		a.Name = ga.Name | ||||
| 		if security, ok := securityMap[ga.Commodity.Name]; ok { | ||||
| 			a.SecurityId = security.SecurityId | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("Unable to find security: %s", ga.Commodity.Name) | ||||
| 		} | ||||
|  | ||||
| 		//TODO find account types | ||||
| 		switch ga.Type { | ||||
| 		default: | ||||
| 			a.Type = Bank | ||||
| 		case "ASSET": | ||||
| 			a.Type = Asset | ||||
| 		case "BANK": | ||||
| 			a.Type = Bank | ||||
| 		case "CASH": | ||||
| 			a.Type = Cash | ||||
| 		case "CREDIT", "LIABILITY": | ||||
| 			a.Type = Liability | ||||
| 		case "EQUITY": | ||||
| 			a.Type = Equity | ||||
| 		case "EXPENSE": | ||||
| 			a.Type = Expense | ||||
| 		case "INCOME": | ||||
| 			a.Type = Income | ||||
| 		case "PAYABLE": | ||||
| 			a.Type = Payable | ||||
| 		case "RECEIVABLE": | ||||
| 			a.Type = Receivable | ||||
| 		case "MUTUAL", "STOCK": | ||||
| 			a.Type = Investment | ||||
| 		case "TRADING": | ||||
| 			a.Type = Trading | ||||
| 		} | ||||
|  | ||||
| 		gncimport.Accounts = append(gncimport.Accounts, a) | ||||
| 	} | ||||
|  | ||||
| 	//Translate transactions to our format | ||||
| 	for i := range gncxml.Transactions { | ||||
| 		gt := gncxml.Transactions[i] | ||||
|  | ||||
| 		t := new(Transaction) | ||||
| 		t.Description = gt.Description | ||||
| 		t.Date = gt.DatePosted.Date.Time | ||||
| 		for j := range gt.Splits { | ||||
| 			gs := gt.Splits[j] | ||||
| 			s := new(Split) | ||||
|  | ||||
| 			switch gs.Status { | ||||
| 			default: // 'n', or not present | ||||
| 				s.Status = Imported | ||||
| 			case "c": | ||||
| 				s.Status = Cleared | ||||
| 			case "y": | ||||
| 				s.Status = Reconciled | ||||
| 			} | ||||
|  | ||||
| 			account, ok := accountMap[gs.AccountId] | ||||
| 			if !ok { | ||||
| 				return nil, fmt.Errorf("Unable to find account: %s", gs.AccountId) | ||||
| 			} | ||||
| 			s.AccountId = account.accountid | ||||
|  | ||||
| 			security, ok := securityMap[account.Commodity.Name] | ||||
| 			if !ok { | ||||
| 				return nil, fmt.Errorf("Unable to find security: %s", account.Commodity.Name) | ||||
| 			} | ||||
| 			s.SecurityId = -1 | ||||
|  | ||||
| 			s.RemoteId = "gnucash:" + gs.SplitId | ||||
| 			s.Number = gt.Number | ||||
| 			s.Memo = gs.Memo | ||||
|  | ||||
| 			var r big.Rat | ||||
| 			_, ok = r.SetString(gs.Amount) | ||||
| 			if ok { | ||||
| 				s.Amount = r.FloatString(security.Precision) | ||||
| 			} else { | ||||
| 				return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount) | ||||
| 			} | ||||
|  | ||||
| 			t.Splits = append(t.Splits, s) | ||||
| 		} | ||||
| 		gncimport.Transactions = append(gncimport.Transactions, *t) | ||||
| 	} | ||||
|  | ||||
| 	return &gncimport, nil | ||||
| } | ||||
|  | ||||
| func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r.Method != "POST" { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	multipartReader, err := r.MultipartReader() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Assume there is only one 'part' and it's the one we care about | ||||
| 	part, err := multipartReader.NextPart() | ||||
| 	if err != nil { | ||||
| 		if err == io.EOF { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 		} else { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	bufread := bufio.NewReader(part) | ||||
| 	gzHeader, err := bufread.Peek(2) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Does this look like a gzipped file? | ||||
| 	var gnucashImport *GnucashImport | ||||
| 	if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b { | ||||
| 		gzr, err := gzip.NewReader(bufread) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		gnucashImport, err = ImportGnucash(gzr) | ||||
| 	} else { | ||||
| 		gnucashImport, err = ImportGnucash(bufread) | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sqltransaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Import securities, building map from Gnucash security IDs to our | ||||
| 	// internal IDs | ||||
| 	securityMap := make(map[int64]int64) | ||||
| 	for _, security := range gnucashImport.Securities { | ||||
| 		securityId := security.SecurityId // save off because it could be updated | ||||
| 		s, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &security) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 6 /*Import Error*/) | ||||
| 			log.Print(err) | ||||
| 			log.Print(security) | ||||
| 			return | ||||
| 		} | ||||
| 		securityMap[securityId] = s.SecurityId | ||||
| 	} | ||||
|  | ||||
| 	// Import prices, setting security and currency IDs from securityMap | ||||
| 	for _, price := range gnucashImport.Prices { | ||||
| 		price.SecurityId = securityMap[price.SecurityId] | ||||
| 		price.CurrencyId = securityMap[price.CurrencyId] | ||||
| 		price.PriceId = 0 | ||||
|  | ||||
| 		err := CreatePriceIfNotExist(sqltransaction, &price) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 6 /*Import Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Get/create accounts in the database, building a map from Gnucash account | ||||
| 	// IDs to our internal IDs as we go | ||||
| 	accountMap := make(map[int64]int64) | ||||
| 	accountsRemaining := len(gnucashImport.Accounts) | ||||
| 	accountsRemainingLast := accountsRemaining | ||||
| 	for accountsRemaining > 0 { | ||||
| 		for _, account := range gnucashImport.Accounts { | ||||
|  | ||||
| 			// If the account has already been added to the map, skip it | ||||
| 			_, ok := accountMap[account.AccountId] | ||||
| 			if ok { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// If it hasn't been added, but its parent has, add it to the map | ||||
| 			_, ok = accountMap[account.ParentAccountId] | ||||
| 			if ok || account.ParentAccountId == -1 { | ||||
| 				account.UserId = user.UserId | ||||
| 				if account.ParentAccountId != -1 { | ||||
| 					account.ParentAccountId = accountMap[account.ParentAccountId] | ||||
| 				} | ||||
| 				account.SecurityId = securityMap[account.SecurityId] | ||||
| 				a, err := GetCreateAccountTx(sqltransaction, account) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 				} | ||||
| 				accountMap[account.AccountId] = a.AccountId | ||||
| 				accountsRemaining-- | ||||
| 			} | ||||
| 		} | ||||
| 		if accountsRemaining == accountsRemainingLast { | ||||
| 			//We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName())) | ||||
| 			return | ||||
| 		} | ||||
| 		accountsRemainingLast = accountsRemaining | ||||
| 	} | ||||
|  | ||||
| 	// Insert transactions, fixing up account IDs to match internal ones from | ||||
| 	// above | ||||
| 	for _, transaction := range gnucashImport.Transactions { | ||||
| 		var already_imported bool | ||||
| 		for _, split := range transaction.Splits { | ||||
| 			acctId, ok := accountMap[split.AccountId] | ||||
| 			if !ok { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId)) | ||||
| 				return | ||||
| 			} | ||||
| 			split.AccountId = acctId | ||||
|  | ||||
| 			exists, err := split.AlreadyImportedTx(sqltransaction) | ||||
| 			if err != nil { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print("Error checking if split was already imported:", err) | ||||
| 				return | ||||
| 			} else if exists { | ||||
| 				already_imported = true | ||||
| 			} | ||||
| 		} | ||||
| 		if !already_imported { | ||||
| 			err := InsertTransactionTx(sqltransaction, &transaction, user) | ||||
| 			if err != nil { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = sqltransaction.Commit() | ||||
| 	if err != nil { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	WriteSuccess(w) | ||||
| } | ||||
							
								
								
									
										31
									
								
								internal/handlers/handlers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								internal/handlers/handlers.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| // Create a closure over db, allowing the handlers to look like a | ||||
| // http.HandlerFunc | ||||
| type DB = gorp.DbMap | ||||
| type DBHandler func(http.ResponseWriter, *http.Request, *DB) | ||||
|  | ||||
| func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 		h(w, r, db) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetHandler(db *DB) *http.ServeMux { | ||||
| 	servemux := http.NewServeMux() | ||||
| 	servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db)) | ||||
| 	servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db)) | ||||
| 	servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db)) | ||||
| 	servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) | ||||
| 	servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db)) | ||||
| 	servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db)) | ||||
| 	servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db)) | ||||
| 	servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db)) | ||||
|  | ||||
| 	return servemux | ||||
| } | ||||
							
								
								
									
										409
									
								
								internal/handlers/imports.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										409
									
								
								internal/handlers/imports.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,409 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/aclindsa/ofxgo" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"math/big" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type OFXDownload struct { | ||||
| 	OFXPassword string | ||||
| 	StartDate   time.Time | ||||
| 	EndDate     time.Time | ||||
| } | ||||
|  | ||||
| func (od *OFXDownload) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(od) | ||||
| } | ||||
|  | ||||
| func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) { | ||||
| 	itl, err := ImportOFX(r) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		//TODO is this necessarily an invalid request (what if it was an error on our end)? | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if len(itl.Accounts) != 1 { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sqltransaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Return Account with this Id | ||||
| 	account, err := GetAccountTx(sqltransaction, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	importedAccount := itl.Accounts[0] | ||||
|  | ||||
| 	if len(account.ExternalAccountId) > 0 && | ||||
| 		account.ExternalAccountId != importedAccount.ExternalAccountId { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"", | ||||
| 			importedAccount.ExternalAccountId, | ||||
| 			account.ExternalAccountId) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Find matching existing securities or create new ones for those | ||||
| 	// referenced by the OFX import. Also create a map from placeholder import | ||||
| 	// SecurityIds to the actual SecurityIDs | ||||
| 	var securitymap = make(map[int64]Security) | ||||
| 	for _, ofxsecurity := range itl.Securities { | ||||
| 		// save off since ImportGetCreateSecurity overwrites SecurityId on | ||||
| 		// ofxsecurity | ||||
| 		oldsecurityid := ofxsecurity.SecurityId | ||||
| 		security, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &ofxsecurity) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		securitymap[oldsecurityid] = *security | ||||
| 	} | ||||
|  | ||||
| 	if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// TODO Ensure all transactions have at least one split in the account | ||||
| 	// we're importing to? | ||||
|  | ||||
| 	var transactions []Transaction | ||||
| 	for _, transaction := range itl.Transactions { | ||||
| 		transaction.UserId = user.UserId | ||||
|  | ||||
| 		if !transaction.Valid() { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print("Unexpected invalid transaction from OFX import") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Ensure that either AccountId or SecurityId is set for this split, | ||||
| 		// and fixup the SecurityId to be a valid one for this user's actual | ||||
| 		// securities instead of a placeholder from the import | ||||
| 		for _, split := range transaction.Splits { | ||||
| 			split.Status = Imported | ||||
| 			if split.AccountId != -1 { | ||||
| 				if split.AccountId != importedAccount.AccountId { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print("Imported split's AccountId wasn't -1 but also didn't match the account") | ||||
| 					return | ||||
| 				} | ||||
| 				split.AccountId = account.AccountId | ||||
| 			} else if split.SecurityId != -1 { | ||||
| 				if sec, ok := securitymap[split.SecurityId]; ok { | ||||
| 					// TODO try to auto-match splits to existing accounts based on past transactions that look like this one | ||||
| 					if split.ImportSplitType == TradingAccount { | ||||
| 						// Find/make trading account if we're that type of split | ||||
| 						trading_account, err := GetTradingAccount(sqltransaction, user.UserId, sec.SecurityId) | ||||
| 						if err != nil { | ||||
| 							sqltransaction.Rollback() | ||||
| 							WriteError(w, 999 /*Internal Error*/) | ||||
| 							log.Print("Couldn't find split's SecurityId in map during OFX import") | ||||
| 							return | ||||
| 						} | ||||
| 						split.AccountId = trading_account.AccountId | ||||
| 						split.SecurityId = -1 | ||||
| 					} else if split.ImportSplitType == SubAccount { | ||||
| 						subaccount := &Account{ | ||||
| 							UserId:          user.UserId, | ||||
| 							Name:            sec.Name, | ||||
| 							ParentAccountId: account.AccountId, | ||||
| 							SecurityId:      sec.SecurityId, | ||||
| 							Type:            account.Type, | ||||
| 						} | ||||
| 						subaccount, err := GetCreateAccountTx(sqltransaction, *subaccount) | ||||
| 						if err != nil { | ||||
| 							sqltransaction.Rollback() | ||||
| 							WriteError(w, 999 /*Internal Error*/) | ||||
| 							log.Print(err) | ||||
| 							return | ||||
| 						} | ||||
| 						split.AccountId = subaccount.AccountId | ||||
| 						split.SecurityId = -1 | ||||
| 					} else { | ||||
| 						split.SecurityId = sec.SecurityId | ||||
| 					} | ||||
| 				} else { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print("Couldn't find split's SecurityId in map during OFX import") | ||||
| 					return | ||||
| 				} | ||||
| 			} else { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import") | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		imbalances, err := transaction.GetImbalancesTx(sqltransaction) | ||||
| 		if err != nil { | ||||
| 			sqltransaction.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Fixup any imbalances in transactions | ||||
| 		var zero big.Rat | ||||
| 		for imbalanced_security, imbalance := range imbalances { | ||||
| 			if imbalance.Cmp(&zero) != 0 { | ||||
| 				imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				// Add new split to fixup imbalance | ||||
| 				split := new(Split) | ||||
| 				r := new(big.Rat) | ||||
| 				r.Neg(&imbalance) | ||||
| 				security, err := GetSecurityTx(sqltransaction, imbalanced_security, user.UserId) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 				} | ||||
| 				split.Amount = r.FloatString(security.Precision) | ||||
| 				split.SecurityId = -1 | ||||
| 				split.AccountId = imbalanced_account.AccountId | ||||
| 				transaction.Splits = append(transaction.Splits, split) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// Move any splits with SecurityId but not AccountId to Imbalances | ||||
| 		// accounts. In the same loop, check to see if this transaction/split | ||||
| 		// has been imported before | ||||
| 		var already_imported bool | ||||
| 		for _, split := range transaction.Splits { | ||||
| 			if split.SecurityId != -1 || split.AccountId == -1 { | ||||
| 				imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId) | ||||
| 				if err != nil { | ||||
| 					sqltransaction.Rollback() | ||||
| 					WriteError(w, 999 /*Internal Error*/) | ||||
| 					log.Print(err) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				split.AccountId = imbalanced_account.AccountId | ||||
| 				split.SecurityId = -1 | ||||
| 			} | ||||
|  | ||||
| 			exists, err := split.AlreadyImportedTx(sqltransaction) | ||||
| 			if err != nil { | ||||
| 				sqltransaction.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print("Error checking if split was already imported:", err) | ||||
| 				return | ||||
| 			} else if exists { | ||||
| 				already_imported = true | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if !already_imported { | ||||
| 			transactions = append(transactions, transaction) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for _, transaction := range transactions { | ||||
| 		err := InsertTransactionTx(sqltransaction, &transaction, user) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = sqltransaction.Commit() | ||||
| 	if err != nil { | ||||
| 		sqltransaction.Rollback() | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	WriteSuccess(w) | ||||
| } | ||||
|  | ||||
| func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { | ||||
| 	download_json := r.PostFormValue("ofxdownload") | ||||
| 	if download_json == "" { | ||||
| 		log.Print("download_json") | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var ofxdownload OFXDownload | ||||
| 	err := ofxdownload.Read(download_json) | ||||
| 	if err != nil { | ||||
| 		log.Print("ofxdownload.Read") | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	account, err := GetAccount(db, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		log.Print("GetAccount") | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	ofxver := ofxgo.OfxVersion203 | ||||
| 	if len(account.OFXVersion) != 0 { | ||||
| 		ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion) | ||||
| 		if err != nil { | ||||
| 			log.Print("NewOfxVersion") | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var client = ofxgo.Client{ | ||||
| 		AppID:       account.OFXAppID, | ||||
| 		AppVer:      account.OFXAppVer, | ||||
| 		SpecVersion: ofxver, | ||||
| 		NoIndent:    account.OFXNoIndent, | ||||
| 	} | ||||
|  | ||||
| 	var query ofxgo.Request | ||||
| 	query.URL = account.OFXURL | ||||
| 	query.Signon.ClientUID = ofxgo.UID(account.OFXClientUID) | ||||
| 	query.Signon.UserID = ofxgo.String(account.OFXUser) | ||||
| 	query.Signon.UserPass = ofxgo.String(ofxdownload.OFXPassword) | ||||
| 	query.Signon.Org = ofxgo.String(account.OFXORG) | ||||
| 	query.Signon.Fid = ofxgo.String(account.OFXFID) | ||||
|  | ||||
| 	transactionuid, err := ofxgo.RandomUID() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Println("Error creating uid for transaction:", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if account.Type == Investment { | ||||
| 		// Investment account | ||||
| 		statementRequest := ofxgo.InvStatementRequest{ | ||||
| 			TrnUID: *transactionuid, | ||||
| 			InvAcctFrom: ofxgo.InvAcct{ | ||||
| 				BrokerID: ofxgo.String(account.OFXBankID), | ||||
| 				AcctID:   ofxgo.String(account.OFXAcctID), | ||||
| 			}, | ||||
| 			Include:        true, | ||||
| 			IncludeOO:      true, | ||||
| 			IncludePos:     true, | ||||
| 			IncludeBalance: true, | ||||
| 			Include401K:    true, | ||||
| 			Include401KBal: true, | ||||
| 		} | ||||
| 		query.InvStmt = append(query.InvStmt, &statementRequest) | ||||
| 	} else if account.OFXAcctType == "CC" { | ||||
| 		// Import credit card transactions | ||||
| 		statementRequest := ofxgo.CCStatementRequest{ | ||||
| 			TrnUID: *transactionuid, | ||||
| 			CCAcctFrom: ofxgo.CCAcct{ | ||||
| 				AcctID: ofxgo.String(account.OFXAcctID), | ||||
| 			}, | ||||
| 			Include: true, | ||||
| 		} | ||||
| 		query.CreditCard = append(query.CreditCard, &statementRequest) | ||||
| 	} else { | ||||
| 		// Import generic bank transactions | ||||
| 		acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		statementRequest := ofxgo.StatementRequest{ | ||||
| 			TrnUID: *transactionuid, | ||||
| 			BankAcctFrom: ofxgo.BankAcct{ | ||||
| 				BankID:   ofxgo.String(account.OFXBankID), | ||||
| 				AcctID:   ofxgo.String(account.OFXAcctID), | ||||
| 				AcctType: acctTypeEnum, | ||||
| 			}, | ||||
| 			Include: true, | ||||
| 		} | ||||
| 		query.Bank = append(query.Bank, &statementRequest) | ||||
| 	} | ||||
|  | ||||
| 	response, err := client.RequestNoParse(&query) | ||||
| 	if err != nil { | ||||
| 		// TODO this could be an error talking with the OFX server... | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
| 	defer response.Body.Close() | ||||
|  | ||||
| 	ofxImportHelper(db, response.Body, w, user, accountid) | ||||
| } | ||||
|  | ||||
| func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { | ||||
| 	multipartReader, err := r.MultipartReader() | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// assume there is only one 'part' | ||||
| 	part, err := multipartReader.NextPart() | ||||
| 	if err != nil { | ||||
| 		if err == io.EOF { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			log.Print("Encountered unexpected EOF") | ||||
| 		} else { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	ofxImportHelper(db, part, w, user, accountid) | ||||
| } | ||||
|  | ||||
| /* | ||||
|  * Assumes the User is a valid, signed-in user, but accountid has not yet been validated | ||||
|  */ | ||||
| func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { | ||||
|  | ||||
| 	switch importtype { | ||||
| 	case "ofx": | ||||
| 		OFXImportHandler(db, w, r, user, accountid) | ||||
| 	case "ofxfile": | ||||
| 		OFXFileImportHandler(db, w, r, user, accountid) | ||||
| 	default: | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										1003
									
								
								internal/handlers/ofx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1003
									
								
								internal/handlers/ofx.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										113
									
								
								internal/handlers/prices.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								internal/handlers/prices.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type Price struct { | ||||
| 	PriceId    int64 | ||||
| 	SecurityId int64 | ||||
| 	CurrencyId int64 | ||||
| 	Date       time.Time | ||||
| 	Value      string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString() | ||||
| 	RemoteId   string // unique ID from source, for detecting duplicates | ||||
| } | ||||
|  | ||||
| func InsertPriceTx(transaction *gorp.Transaction, p *Price) error { | ||||
| 	err := transaction.Insert(p) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { | ||||
| 	if len(price.RemoteId) == 0 { | ||||
| 		// Always create a new price if we can't match on the RemoteId | ||||
| 		err := InsertPriceTx(transaction, price) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var prices []*Price | ||||
|  | ||||
| 	_, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if len(prices) > 0 { | ||||
| 		return nil // price already exists | ||||
| 	} | ||||
|  | ||||
| 	err = InsertPriceTx(transaction, price) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Return the latest price for security in currency units before date | ||||
| func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { | ||||
| 	var p Price | ||||
| 	err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &p, nil | ||||
| } | ||||
|  | ||||
| // Return the earliest price for security in currency units after date | ||||
| func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { | ||||
| 	var p Price | ||||
| 	err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &p, nil | ||||
| } | ||||
|  | ||||
| // Return the price for security in currency closest to date | ||||
| func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { | ||||
| 	earliest, _ := GetEarliestPrice(transaction, security, currency, date) | ||||
| 	latest, err := GetLatestPrice(transaction, security, currency, date) | ||||
|  | ||||
| 	// Return early if either earliest or latest are invalid | ||||
| 	if earliest == nil { | ||||
| 		return latest, err | ||||
| 	} else if err != nil { | ||||
| 		return earliest, nil | ||||
| 	} | ||||
|  | ||||
| 	howlate := earliest.Date.Sub(*date) | ||||
| 	howearly := date.Sub(latest.Date) | ||||
| 	if howearly < howlate { | ||||
| 		return latest, nil | ||||
| 	} else { | ||||
| 		return earliest, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	price, err := GetClosestPriceTx(transaction, security, currency, date) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return price, nil | ||||
| } | ||||
							
								
								
									
										91
									
								
								internal/handlers/prices_lua.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								internal/handlers/prices_lua.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,91 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| const luaPriceTypeName = "price" | ||||
|  | ||||
| func luaRegisterPrices(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaPriceTypeName) | ||||
| 	L.SetGlobal("price", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaPrice__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaPrice__tostring)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| func PriceToLua(L *lua.LState, price *Price) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = price | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaPriceTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Price and returns this *Price. | ||||
| func luaCheckPrice(L *lua.LState, n int) *Price { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if price, ok := ud.Value.(*Price); ok { | ||||
| 		return price | ||||
| 	} | ||||
| 	L.ArgError(n, "price expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaPrice__index(L *lua.LState) int { | ||||
| 	p := luaCheckPrice(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "PriceId", "priceid": | ||||
| 		L.Push(lua.LNumber(float64(p.PriceId))) | ||||
| 	case "Security", "security": | ||||
| 		security_map, err := luaContextGetSecurities(L) | ||||
| 		if err != nil { | ||||
| 			panic("luaContextGetSecurities couldn't fetch securities") | ||||
| 		} | ||||
| 		s, ok := security_map[p.SecurityId] | ||||
| 		if !ok { | ||||
| 			panic("Price's security not found for user") | ||||
| 		} | ||||
| 		L.Push(SecurityToLua(L, s)) | ||||
| 	case "Currency", "currency": | ||||
| 		security_map, err := luaContextGetSecurities(L) | ||||
| 		if err != nil { | ||||
| 			panic("luaContextGetSecurities couldn't fetch securities") | ||||
| 		} | ||||
| 		c, ok := security_map[p.CurrencyId] | ||||
| 		if !ok { | ||||
| 			panic("Price's currency not found for user") | ||||
| 		} | ||||
| 		L.Push(SecurityToLua(L, c)) | ||||
| 	case "Value", "value": | ||||
| 		amt, err := GetBigAmount(p.Value) | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		float, _ := amt.Float64() | ||||
| 		L.Push(lua.LNumber(float)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected price attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaPrice__tostring(L *lua.LState) int { | ||||
| 	p := luaCheckPrice(L, 1) | ||||
|  | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaContextGetSecurities couldn't fetch securities") | ||||
| 	} | ||||
| 	s, ok1 := security_map[p.SecurityId] | ||||
| 	c, ok2 := security_map[p.CurrencyId] | ||||
| 	if !ok1 || !ok2 { | ||||
| 		panic("Price's currency or security not found for user") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")")) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										354
									
								
								internal/handlers/reports.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										354
									
								
								internal/handlers/reports.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,354 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var reportTabulationRE *regexp.Regexp | ||||
|  | ||||
| func init() { | ||||
| 	reportTabulationRE = regexp.MustCompile(`^/report/[0-9]+/tabulation/?$`) | ||||
| } | ||||
|  | ||||
| //type and value to store user in lua's Context | ||||
| type key int | ||||
|  | ||||
| const ( | ||||
| 	userContextKey key = iota | ||||
| 	accountsContextKey | ||||
| 	securitiesContextKey | ||||
| 	balanceContextKey | ||||
| 	dbContextKey | ||||
| ) | ||||
|  | ||||
| const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for | ||||
|  | ||||
| type Report struct { | ||||
| 	ReportId int64 | ||||
| 	UserId   int64 | ||||
| 	Name     string | ||||
| 	Lua      string | ||||
| } | ||||
|  | ||||
| func (r *Report) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(r) | ||||
| } | ||||
|  | ||||
| func (r *Report) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(r) | ||||
| } | ||||
|  | ||||
| type ReportList struct { | ||||
| 	Reports *[]Report `json:"reports"` | ||||
| } | ||||
|  | ||||
| func (rl *ReportList) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(rl) | ||||
| } | ||||
|  | ||||
| type Series struct { | ||||
| 	Values []float64 | ||||
| 	Series map[string]*Series | ||||
| } | ||||
|  | ||||
| type Tabulation struct { | ||||
| 	ReportId int64 | ||||
| 	Title    string | ||||
| 	Subtitle string | ||||
| 	Units    string | ||||
| 	Labels   []string | ||||
| 	Series   map[string]*Series | ||||
| } | ||||
|  | ||||
| func (r *Tabulation) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(r) | ||||
| } | ||||
|  | ||||
| func GetReport(db *DB, reportid int64, userid int64) (*Report, error) { | ||||
| 	var r Report | ||||
|  | ||||
| 	err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &r, nil | ||||
| } | ||||
|  | ||||
| func GetReports(db *DB, userid int64) (*[]Report, error) { | ||||
| 	var reports []Report | ||||
|  | ||||
| 	_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &reports, nil | ||||
| } | ||||
|  | ||||
| func InsertReport(db *DB, r *Report) error { | ||||
| 	err := db.Insert(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateReport(db *DB, r *Report) error { | ||||
| 	count, err := db.Update(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Updated more than one report") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteReport(db *DB, r *Report) error { | ||||
| 	count, err := db.Delete(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Deleted more than one report") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { | ||||
| 	// Create a new LState without opening the default libs for security | ||||
| 	L := lua.NewState(lua.Options{SkipOpenLibs: true}) | ||||
| 	defer L.Close() | ||||
|  | ||||
| 	// Create a new context holding the current user with a timeout | ||||
| 	ctx := context.WithValue(context.Background(), userContextKey, user) | ||||
| 	ctx = context.WithValue(ctx, dbContextKey, db) | ||||
| 	ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) | ||||
| 	defer cancel() | ||||
| 	L.SetContext(ctx) | ||||
|  | ||||
| 	for _, pair := range []struct { | ||||
| 		n string | ||||
| 		f lua.LGFunction | ||||
| 	}{ | ||||
| 		{lua.LoadLibName, lua.OpenPackage}, // Must be first | ||||
| 		{lua.BaseLibName, lua.OpenBase}, | ||||
| 		{lua.TabLibName, lua.OpenTable}, | ||||
| 		{lua.StringLibName, lua.OpenString}, | ||||
| 		{lua.MathLibName, lua.OpenMath}, | ||||
| 	} { | ||||
| 		if err := L.CallByParam(lua.P{ | ||||
| 			Fn:      L.NewFunction(pair.f), | ||||
| 			NRet:    0, | ||||
| 			Protect: true, | ||||
| 		}, lua.LString(pair.n)); err != nil { | ||||
| 			return nil, errors.New("Error initializing Lua packages") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	luaRegisterAccounts(L) | ||||
| 	luaRegisterSecurities(L) | ||||
| 	luaRegisterBalances(L) | ||||
| 	luaRegisterDates(L) | ||||
| 	luaRegisterTabulations(L) | ||||
| 	luaRegisterPrices(L) | ||||
|  | ||||
| 	err := L.DoString(report.Lua) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if err := L.CallByParam(lua.P{ | ||||
| 		Fn:      L.GetGlobal("generate"), | ||||
| 		NRet:    1, | ||||
| 		Protect: true, | ||||
| 	}); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	value := L.Get(-1) | ||||
| 	if ud, ok := value.(*lua.LUserData); ok { | ||||
| 		if tabulation, ok := ud.Value.(*Tabulation); ok { | ||||
| 			return tabulation, nil | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId) | ||||
| 		} | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("generate() for %s (Id: %d) didn't even return LUserData", report.Name, report.ReportId) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) { | ||||
| 	report, err := GetReport(db, reportid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tabulation, err := runReport(db, user, report) | ||||
| 	if err != nil { | ||||
| 		// TODO handle different failure cases differently | ||||
| 		log.Print("runReport returned:", err) | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tabulation.ReportId = reportid | ||||
|  | ||||
| 	err = tabulation.Write(w) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		report_json := r.PostFormValue("report") | ||||
| 		if report_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var report Report | ||||
| 		err := report.Read(report_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		report.ReportId = -1 | ||||
| 		report.UserId = user.UserId | ||||
|  | ||||
| 		err = InsertReport(db, &report) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = report.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		if reportTabulationRE.MatchString(r.URL.Path) { | ||||
| 			var reportid int64 | ||||
| 			n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid) | ||||
| 			if err != nil || n != 1 { | ||||
| 				WriteError(w, 999 /*InternalError*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			ReportTabulationHandler(db, w, r, user, reportid) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var reportid int64 | ||||
| 		n, err := GetURLPieces(r.URL.Path, "/report/%d", &reportid) | ||||
| 		if err != nil || n != 1 { | ||||
| 			//Return all Reports | ||||
| 			var rl ReportList | ||||
| 			reports, err := GetReports(db, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			rl.Reports = reports | ||||
| 			err = (&rl).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			// Return Report with this Id | ||||
| 			report, err := GetReport(db, reportid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = report.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		reportid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if r.Method == "PUT" { | ||||
| 			report_json := r.PostFormValue("report") | ||||
| 			if report_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var report Report | ||||
| 			err := report.Read(report_json) | ||||
| 			if err != nil || report.ReportId != reportid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			report.UserId = user.UserId | ||||
|  | ||||
| 			err = UpdateReport(db, &report) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = report.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			report, err := GetReport(db, reportid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteReport(db, report) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										42
									
								
								internal/handlers/reports/asset_allocation.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								internal/handlers/reports/asset_allocation.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| function generate() | ||||
|     accounts = get_accounts() | ||||
|     securities = get_securities() | ||||
|     default_currency = get_default_currency() | ||||
|     series_map = {} | ||||
|     totals_map = {} | ||||
|  | ||||
|     t = tabulation.new(1) | ||||
|     t:title("Current Asset Allocation") | ||||
|  | ||||
|     t:label(1, "Assets") | ||||
|  | ||||
|     for id, security in pairs(securities) do | ||||
|         totals_map[id] = 0 | ||||
|         series_map[id] = t:series(tostring(security)) | ||||
|     end | ||||
|  | ||||
|     for id, acct in pairs(accounts) do | ||||
|         if acct.type == account.Asset or acct.type == account.Investment or acct.type == account.Bank or acct.type == account.Cash then | ||||
|             balance = acct:balance() | ||||
|             multiplier = 1 | ||||
|             if acct.security ~= default_currency and balance.amount ~= 0 then | ||||
|                 price = acct.security:closestprice(default_currency, date.now()) | ||||
|                 if price == nil then | ||||
|                     --[[ | ||||
|                     -- This should contain code to warn the user that their report is missing some information | ||||
|                     --]] | ||||
|                     multiplier = 0 | ||||
|                 else | ||||
|                     multiplier = price.value | ||||
|                 end | ||||
|             end | ||||
|             totals_map[acct.security.SecurityId] = balance.amount * multiplier + totals_map[acct.security.SecurityId] | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     for id, series in pairs(series_map) do | ||||
|         series:value(1, totals_map[id]) | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										26
									
								
								internal/handlers/reports/monthly_cash_flow.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/handlers/reports/monthly_cash_flow.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| function generate() | ||||
|     year = date.now().year | ||||
|  | ||||
|     accounts = get_accounts() | ||||
|     t = tabulation.new(12) | ||||
|     t:title(year .. " Monthly Cash Flow") | ||||
|     series = t:series("Income minus expenses") | ||||
|  | ||||
|     for month=1,12 do | ||||
|         begin_date = date.new(year, month, 1) | ||||
|         end_date = date.new(year, month+1, 1) | ||||
|  | ||||
|         t:label(month, tostring(begin_date)) | ||||
|         cash_flow = 0 | ||||
|  | ||||
|         for id, acct in pairs(accounts) do | ||||
|             if acct.type == account.Expense or acct.type == account.Income then | ||||
|                 balance = acct:balance(begin_date, end_date) | ||||
|                 cash_flow = cash_flow - balance.amount | ||||
|             end | ||||
|         end | ||||
|         series:value(month, cash_flow) | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										49
									
								
								internal/handlers/reports/monthly_expenses.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								internal/handlers/reports/monthly_expenses.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| function account_series_map(accounts, tabulation) | ||||
|     map = {} | ||||
|  | ||||
|     for i=1,100 do -- we're not messing with accounts more than 100 levels deep | ||||
|         all_handled = true | ||||
|         for id, acct in pairs(accounts) do | ||||
|             if not map[id] then | ||||
|                 all_handled = false | ||||
|                 if not acct.parent then | ||||
|                     map[id] = tabulation:series(acct.name) | ||||
|                 elseif map[acct.parent.accountid] then | ||||
|                     map[id] = map[acct.parent.accountid]:series(acct.name) | ||||
|                 end | ||||
|             end | ||||
|         end | ||||
|         if all_handled then | ||||
|             return map | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     error("Accounts nested (at least) 100 levels deep") | ||||
| end | ||||
|  | ||||
| function generate() | ||||
|     year = date.now().year | ||||
|     account_type = account.Expense | ||||
|  | ||||
|     accounts = get_accounts() | ||||
|     t = tabulation.new(12) | ||||
|     t:title(year .. " Monthly Expenses") | ||||
|     series_map = account_series_map(accounts, t) | ||||
|  | ||||
|     for month=1,12 do | ||||
|         begin_date = date.new(year, month, 1) | ||||
|         end_date = date.new(year, month+1, 1) | ||||
|  | ||||
|         t:label(month, tostring(begin_date)) | ||||
|  | ||||
|         for id, acct in pairs(accounts) do | ||||
|             series = series_map[id] | ||||
|             if acct.type == account_type then | ||||
|                 balance = acct:balance(begin_date, end_date) | ||||
|                 series:value(month, balance.amount) | ||||
|             end | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										60
									
								
								internal/handlers/reports/monthly_net_worth.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/handlers/reports/monthly_net_worth.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| function account_series_map(accounts, tabulation) | ||||
|     map = {} | ||||
|  | ||||
|     for i=1,100 do -- we're not messing with accounts more than 100 levels deep | ||||
|         all_handled = true | ||||
|         for id, acct in pairs(accounts) do | ||||
|             if not map[id] then | ||||
|                 all_handled = false | ||||
|                 if not acct.parent then | ||||
|                     map[id] = tabulation:series(acct.name) | ||||
|                 elseif map[acct.parent.accountid] then | ||||
|                     map[id] = map[acct.parent.accountid]:series(acct.name) | ||||
|                 end | ||||
|             end | ||||
|         end | ||||
|         if all_handled then | ||||
|             return map | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     error("Accounts nested (at least) 100 levels deep") | ||||
| end | ||||
|  | ||||
| function generate() | ||||
|     year = date.now().year | ||||
|  | ||||
|     accounts = get_accounts() | ||||
|     t = tabulation.new(12) | ||||
|     t:title(year .. " Monthly Net Worth") | ||||
|     series_map = account_series_map(accounts, t) | ||||
|     default_currency = get_default_currency() | ||||
|  | ||||
|     for month=1,12 do | ||||
|         end_date = date.new(year, month+1, 1) | ||||
|  | ||||
|         t:label(month, tostring(end_date)) | ||||
|  | ||||
|         for id, acct in pairs(accounts) do | ||||
|             series = series_map[id] | ||||
|             if acct.type ~= account.Expense and acct.type ~= account.Income and acct.type ~= account.Trading then | ||||
|                 balance = acct:balance(end_date) | ||||
|                 multiplier = 1 | ||||
|                 if acct.security ~= default_currency and balance.amount ~= 0 then | ||||
|                     price = acct.security:closestprice(default_currency, end_date) | ||||
|                     if price == nil then | ||||
|                         --[[ | ||||
|                         -- This should contain code to warn the user that their report is missing some information | ||||
|                         --]] | ||||
|                         multiplier = 0 | ||||
|                     else | ||||
|                         multiplier = price.value | ||||
|                     end | ||||
|                 end | ||||
|                 series:value(month, balance.amount * multiplier) | ||||
|             end | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										61
									
								
								internal/handlers/reports/monthly_net_worth_change.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								internal/handlers/reports/monthly_net_worth_change.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| function account_series_map(accounts, tabulation) | ||||
|     map = {} | ||||
|  | ||||
|     for i=1,100 do -- we're not messing with accounts more than 100 levels deep | ||||
|         all_handled = true | ||||
|         for id, acct in pairs(accounts) do | ||||
|             if not map[id] then | ||||
|                 all_handled = false | ||||
|                 if not acct.parent then | ||||
|                     map[id] = tabulation:series(acct.name) | ||||
|                 elseif map[acct.parent.accountid] then | ||||
|                     map[id] = map[acct.parent.accountid]:series(acct.name) | ||||
|                 end | ||||
|             end | ||||
|         end | ||||
|         if all_handled then | ||||
|             return map | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     error("Accounts nested (at least) 100 levels deep") | ||||
| end | ||||
|  | ||||
| function generate() | ||||
|     year = date.now().year | ||||
|  | ||||
|     accounts = get_accounts() | ||||
|     t = tabulation.new(12) | ||||
|     t:title(year .. " Monthly Net Worth") | ||||
|     series_map = account_series_map(accounts, t) | ||||
|     default_currency = get_default_currency() | ||||
|  | ||||
|     for month=1,12 do | ||||
|         begin_date = date.new(year, month, 1) | ||||
|         end_date = date.new(year, month+1, 1) | ||||
|  | ||||
|         t:label(month, tostring(begin_date)) | ||||
|  | ||||
|         for id, acct in pairs(accounts) do | ||||
|             series = series_map[id] | ||||
|             if acct.type ~= account.Expense and acct.type ~= account.Income and acct.type ~= account.Trading then | ||||
|                 balance = acct:balance(begin_date, end_date) | ||||
|                 multiplier = 1 | ||||
|                 if acct.security ~= default_currency then | ||||
|                     price = acct.security:closestprice(default_currency, end_date) | ||||
|                     if price == nil then | ||||
|                         --[[ | ||||
|                         -- This should contain code to warn the user that their report is missing some information | ||||
|                         --]] | ||||
|                         multiplier = 0 | ||||
|                     else | ||||
|                         multiplier = price.value | ||||
|                     end | ||||
|                 end | ||||
|                 series:value(month, balance.amount * multiplier) | ||||
|             end | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										60
									
								
								internal/handlers/reports/quarterly_net_worth.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/handlers/reports/quarterly_net_worth.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| function account_series_map(accounts, tabulation) | ||||
|     map = {} | ||||
|  | ||||
|     for i=1,100 do -- we're not messing with accounts more than 100 levels deep | ||||
|         all_handled = true | ||||
|         for id, acct in pairs(accounts) do | ||||
|             if not map[id] then | ||||
|                 all_handled = false | ||||
|                 if not acct.parent then | ||||
|                     map[id] = tabulation:series(acct.name) | ||||
|                 elseif map[acct.parent.accountid] then | ||||
|                     map[id] = map[acct.parent.accountid]:series(acct.name) | ||||
|                 end | ||||
|             end | ||||
|         end | ||||
|         if all_handled then | ||||
|             return map | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     error("Accounts nested (at least) 100 levels deep") | ||||
| end | ||||
|  | ||||
| function generate() | ||||
|     year = date.now().year-4 | ||||
|  | ||||
|     accounts = get_accounts() | ||||
|     t = tabulation.new(20) | ||||
|     t:title(year .. "-" .. date.now().year .. " Quarterly Net Worth") | ||||
|     series_map = account_series_map(accounts, t:series("Net Worth")) | ||||
|     default_currency = get_default_currency() | ||||
|  | ||||
|     for month=1,20 do | ||||
|         end_date = date.new(year, month*3-2, 1) | ||||
|  | ||||
|         t:label(month, tostring(end_date)) | ||||
|  | ||||
|         for id, acct in pairs(accounts) do | ||||
|             series = series_map[id] | ||||
|             if acct.type ~= account.Expense and acct.type ~= account.Income and acct.type ~= account.Trading then | ||||
|                 balance = acct:balance(end_date) | ||||
|                 multiplier = 1 | ||||
|                 if acct.security ~= default_currency then | ||||
|                     price = acct.security:closestprice(default_currency, end_date) | ||||
|                     if price == nil then | ||||
|                         --[[ | ||||
|                         -- This should contain code to warn the user that their report is missing some information | ||||
|                         --]] | ||||
|                         multiplier = 0 | ||||
|                     else | ||||
|                         multiplier = price.value | ||||
|                     end | ||||
|                 end | ||||
|                 series:value(month, balance.amount * multiplier) | ||||
|             end | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										47
									
								
								internal/handlers/reports/years_income.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								internal/handlers/reports/years_income.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| function account_series_map(accounts, tabulation) | ||||
|     map = {} | ||||
|  | ||||
|     for i=1,100 do -- we're not messing with accounts more than 100 levels deep | ||||
|         all_handled = true | ||||
|         for id, acct in pairs(accounts) do | ||||
|             if not map[id] then | ||||
|                 all_handled = false | ||||
|                 if not acct.parent then | ||||
|                     map[id] = tabulation:series(acct.name) | ||||
|                 elseif map[acct.parent.accountid] then | ||||
|                     map[id] = map[acct.parent.accountid]:series(acct.name) | ||||
|                 end | ||||
|             end | ||||
|         end | ||||
|         if all_handled then | ||||
|             return map | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     error("Accounts nested (at least) 100 levels deep") | ||||
| end | ||||
|  | ||||
| function generate() | ||||
|     year = date.now().year | ||||
|     account_type = account.Income | ||||
|  | ||||
|     accounts = get_accounts() | ||||
|     t = tabulation.new(1) | ||||
|     t:title(year .. " Income") | ||||
|     series_map = account_series_map(accounts, t) | ||||
|  | ||||
|     begin_date = date.new(year, 1, 1) | ||||
|     end_date = date.new(year+1, 1, 1) | ||||
|  | ||||
|     t:label(1, year .. " Income") | ||||
|  | ||||
|     for id, acct in pairs(accounts) do | ||||
|         series = series_map[id] | ||||
|         if acct.type == account_type then | ||||
|             balance = acct:balance(begin_date, end_date) | ||||
|             series:value(1, balance.amount) | ||||
|         end | ||||
|     end | ||||
|  | ||||
|     return t | ||||
| end | ||||
							
								
								
									
										187
									
								
								internal/handlers/reports_lua.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										187
									
								
								internal/handlers/reports_lua.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,187 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| const luaTabulationTypeName = "tabulation" | ||||
| const luaSeriesTypeName = "series" | ||||
|  | ||||
| func luaRegisterTabulations(L *lua.LState) { | ||||
| 	mtr := L.NewTypeMetatable(luaTabulationTypeName) | ||||
| 	L.SetGlobal("tabulation", mtr) | ||||
| 	L.SetField(mtr, "new", L.NewFunction(luaTabulationNew)) | ||||
| 	L.SetField(mtr, "__index", L.NewFunction(luaTabulation__index)) | ||||
| 	L.SetField(mtr, "__metatable", lua.LString("protected")) | ||||
|  | ||||
| 	mts := L.NewTypeMetatable(luaSeriesTypeName) | ||||
| 	L.SetGlobal("series", mts) | ||||
| 	L.SetField(mts, "__index", L.NewFunction(luaSeries__index)) | ||||
| 	L.SetField(mts, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Tabulation and returns *Tabulation | ||||
| func luaCheckTabulation(L *lua.LState, n int) *Tabulation { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if tabulation, ok := ud.Value.(*Tabulation); ok { | ||||
| 		return tabulation | ||||
| 	} | ||||
| 	L.ArgError(n, "tabulation expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Series and returns *Series | ||||
| func luaCheckSeries(L *lua.LState, n int) *Series { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if series, ok := ud.Value.(*Series); ok { | ||||
| 		return series | ||||
| 	} | ||||
| 	L.ArgError(n, "series expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaTabulationNew(L *lua.LState) int { | ||||
| 	numvalues := L.CheckInt(1) | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = &Tabulation{ | ||||
| 		Labels: make([]string, numvalues), | ||||
| 		Series: make(map[string]*Series), | ||||
| 	} | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaTabulationTypeName)) | ||||
| 	L.Push(ud) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulation__index(L *lua.LState) int { | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Label", "label": | ||||
| 		L.Push(L.NewFunction(luaTabulationLabel)) | ||||
| 	case "Series", "series": | ||||
| 		L.Push(L.NewFunction(luaTabulationSeries)) | ||||
| 	case "Title", "title": | ||||
| 		L.Push(L.NewFunction(luaTabulationTitle)) | ||||
| 	case "Subtitle", "subtitle": | ||||
| 		L.Push(L.NewFunction(luaTabulationSubtitle)) | ||||
| 	case "Units", "units": | ||||
| 		L.Push(L.NewFunction(luaTabulationUnits)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected tabulation attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationLabel(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
| 	labelnumber := L.CheckInt(2) | ||||
| 	label := L.CheckString(3) | ||||
|  | ||||
| 	if labelnumber > cap(tabulation.Labels) || labelnumber < 1 { | ||||
| 		L.ArgError(2, "Label index must be between 1 and the number of data points, inclusive") | ||||
| 	} | ||||
| 	tabulation.Labels[labelnumber-1] = label | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func luaTabulationSeries(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
| 	name := L.CheckString(2) | ||||
| 	ud := L.NewUserData() | ||||
|  | ||||
| 	s, ok := tabulation.Series[name] | ||||
| 	if ok { | ||||
| 		ud.Value = s | ||||
| 	} else { | ||||
| 		tabulation.Series[name] = &Series{ | ||||
| 			Series: make(map[string]*Series), | ||||
| 			Values: make([]float64, cap(tabulation.Labels)), | ||||
| 		} | ||||
| 		ud.Value = tabulation.Series[name] | ||||
| 	} | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName)) | ||||
| 	L.Push(ud) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationTitle(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
|  | ||||
| 	if L.GetTop() == 2 { | ||||
| 		tabulation.Title = L.CheckString(2) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	L.Push(lua.LString(tabulation.Title)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationSubtitle(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
|  | ||||
| 	if L.GetTop() == 2 { | ||||
| 		tabulation.Subtitle = L.CheckString(2) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	L.Push(lua.LString(tabulation.Subtitle)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationUnits(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
|  | ||||
| 	if L.GetTop() == 2 { | ||||
| 		tabulation.Units = L.CheckString(2) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	L.Push(lua.LString(tabulation.Units)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSeries__index(L *lua.LState) int { | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Value", "value": | ||||
| 		L.Push(L.NewFunction(luaSeriesValue)) | ||||
| 	case "Series", "series": | ||||
| 		L.Push(L.NewFunction(luaSeriesSeries)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected series attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSeriesValue(L *lua.LState) int { | ||||
| 	series := luaCheckSeries(L, 1) | ||||
| 	valuenumber := L.CheckInt(2) | ||||
| 	value := float64(L.CheckNumber(3)) | ||||
|  | ||||
| 	if valuenumber > cap(series.Values) || valuenumber < 1 { | ||||
| 		L.ArgError(2, "value index must be between 1 and the number of data points, inclusive") | ||||
| 	} | ||||
| 	series.Values[valuenumber-1] = value | ||||
|  | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func luaSeriesSeries(L *lua.LState) int { | ||||
| 	parent := luaCheckSeries(L, 1) | ||||
| 	name := L.CheckString(2) | ||||
| 	ud := L.NewUserData() | ||||
|  | ||||
| 	s, ok := parent.Series[name] | ||||
| 	if ok { | ||||
| 		ud.Value = s | ||||
| 	} else { | ||||
| 		parent.Series[name] = &Series{ | ||||
| 			Series: make(map[string]*Series), | ||||
| 			Values: make([]float64, cap(parent.Values)), | ||||
| 		} | ||||
| 		ud.Value = parent.Series[name] | ||||
| 	} | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName)) | ||||
| 	L.Push(ud) | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										38
									
								
								internal/handlers/scripts/gen_cusip_csv.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										38
									
								
								internal/handlers/scripts/gen_cusip_csv.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| #!/bin/bash | ||||
| QUARTER=2017q1 | ||||
|  | ||||
| function get_ticker() { | ||||
| 	local cusip=$1 | ||||
|  | ||||
| 	local tmpfile=$tmpdir/curl_tmpfile | ||||
| 	curl -s -d "sopt=cusip&tickersymbol=${cusip}" http://quantumonline.com/search.cfm > $tmpfile | ||||
| 	local quantum_name=$(sed -rn 's@<font size="\+1"><center><b>(.+)</b><br></center></font>\s*$@\1@p' $tmpfile | head -n1) | ||||
| 	local quantum_ticker=$(sed -rn 's@^.*Ticker Symbol: ([A-Z\.0-9\-]+)     CUSIP.*$@\1@p' $tmpfile | head -n1) | ||||
|  | ||||
| 	if [[ -z $quantum_ticker ]] || [[ -z $quantum_name ]]; then | ||||
| 		curl -s -d "reqforlookup=REQUESTFORLOOKUP&productid=mmnet&isLoggedIn=mmnet&rows=50&for=stock&by=cusip&criteria=${cusip}&submit=Search" http://quotes.fidelity.com/mmnet/SymLookup.phtml > $tmpfile | ||||
| 		fidelity_name=$(sed -rn 's@<tr><td height="20" nowrap><font class="smallfont">(.+)</font></td>\s*@\1@p' $tmpfile | sed -r 's/\&/\&/') | ||||
| 		fidelity_ticker=$(sed -rn 's@\s+<td align="center" width="20%"><font><a href="/webxpress/get_quote\?QUOTE_TYPE=\&SID_VALUE_ID=(.+)">(.+)</a></td>\s*@\1@p' $tmpfile | head -n1) | ||||
| 		if [[ -z $fidelity_ticker ]] || [[ -z $fidelity_name ]]; then | ||||
| 			echo $cusip >> $tmpdir/${QUARTER}_bad_cusips.csv | ||||
| 		else | ||||
| 			echo "$cusip,$fidelity_ticker,$fidelity_name" | ||||
| 		fi | ||||
| 	else | ||||
| 		echo "$cusip,$quantum_ticker,$quantum_name" | ||||
| 	fi | ||||
| } | ||||
|  | ||||
| tmpdir=$(mktemp -d -p $PWD) | ||||
|  | ||||
| # Get the list of CUSIPs from the SEC and generate a nicer format of it | ||||
| wget -q http://www.sec.gov/divisions/investment/13f/13flist${QUARTER}.pdf -O $tmpdir/13flist${QUARTER}.pdf | ||||
| pdftotext -layout $tmpdir/13flist${QUARTER}.pdf - > $tmpdir/13flist${QUARTER}.txt | ||||
| sed -rn 's/^([A-Z0-9]{6}) ([A-Z0-9]{2}) ([A-Z0-9]) .*$/\1\2\3/p' $tmpdir/13flist${QUARTER}.txt > $tmpdir/${QUARTER}_cusips | ||||
|  | ||||
| # Find tickers and names for all the CUSIPs we can and print them out | ||||
| for cusip in $(cat $tmpdir/${QUARTER}_cusips); do | ||||
| 	get_ticker $cusip | ||||
| done | ||||
|  | ||||
| rm -rf $tmpdir | ||||
							
								
								
									
										114
									
								
								internal/handlers/scripts/gen_security_list.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										114
									
								
								internal/handlers/scripts/gen_security_list.py
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| #!/usr/bin/env python | ||||
|  | ||||
| import csv | ||||
| from xml.dom import minidom | ||||
| import sys | ||||
|  | ||||
| if sys.version_info[0] < 3: | ||||
|     from urllib2 import urlopen | ||||
|  | ||||
|     # Allow writing utf-8 to stdout | ||||
|     import codecs | ||||
|     UTF8Writer = codecs.getwriter('utf8') | ||||
|     sys.stdout = UTF8Writer(sys.stdout) | ||||
| else: | ||||
|     from urllib.request import urlopen | ||||
|  | ||||
|     # This is absent, but also unneeded in python3, so just return the string | ||||
|     def unicode(s, encoding): | ||||
|         return s | ||||
|  | ||||
| class Security(object): | ||||
|     def __init__(self, name, description, number, _type, precision): | ||||
|         self.name = name | ||||
|         self.description = description | ||||
|         self.number = number | ||||
|         self.type = _type | ||||
|         self.precision = precision | ||||
|     def unicode(self): | ||||
|         s = """\tSecurity{ | ||||
| \t\tName: \"%s\", | ||||
| \t\tDescription: \"%s\", | ||||
| \t\tSymbol: \"%s\", | ||||
| \t\tPrecision: %d, | ||||
| \t\tType: %s, | ||||
| \t\tAlternateId: \"%s\"},\n""" % (self.name, self.description, self.name, self.precision, self.type, str(self.number)) | ||||
|         try: | ||||
|             return unicode(s, 'utf_8') | ||||
|         except TypeError: | ||||
|             return s | ||||
|  | ||||
| class SecurityList(object): | ||||
|     def __init__(self, comment): | ||||
|         self.comment = comment | ||||
|         self.currencies = {} | ||||
|     def add(self, currency): | ||||
|         self.currencies[currency.number] = currency | ||||
|     def unicode(self): | ||||
|         string = "\t// "+self.comment+"\n" | ||||
|         for key in sorted(self.currencies.keys()): | ||||
|             string += self.currencies[key].unicode() | ||||
|         return string | ||||
|  | ||||
| def process_ccyntry(currency_list, node): | ||||
|     name = "" | ||||
|     nameSet = False | ||||
|     number = 0 | ||||
|     numberSet = False | ||||
|     description = "" | ||||
|     precision = 0 | ||||
|     for n in node.childNodes: | ||||
|         if n.nodeName == "Ccy": | ||||
|             name = n.firstChild.nodeValue | ||||
|             nameSet = True | ||||
|         elif n.nodeName == "CcyNm": | ||||
|             description = n.firstChild.nodeValue | ||||
|         elif n.nodeName == "CcyNbr": | ||||
|             number = int(n.firstChild.nodeValue) | ||||
|             numberSet = True | ||||
|         elif n.nodeName == "CcyMnrUnts": | ||||
|             if n.firstChild.nodeValue == "N.A.": | ||||
|                 precision = 0 | ||||
|             else: | ||||
|                 precision = int(n.firstChild.nodeValue) | ||||
|     if nameSet and numberSet: | ||||
|         currency_list.add(Security(name, description, number, "Currency", precision)) | ||||
|  | ||||
| def get_currency_list(): | ||||
|     currency_list = SecurityList("ISO 4217, from http://www.currency-iso.org/en/home/tables/table-a1.html") | ||||
|  | ||||
|     f = urlopen('http://www.currency-iso.org/dam/downloads/lists/list_one.xml') | ||||
|     xmldoc = minidom.parse(f) | ||||
|     for isonode in xmldoc.childNodes: | ||||
|         if isonode.nodeName == "ISO_4217": | ||||
|             for ccytblnode in isonode.childNodes: | ||||
|                 if ccytblnode.nodeName == "CcyTbl": | ||||
|                     for ccyntrynode in ccytblnode.childNodes: | ||||
|                         if ccyntrynode.nodeName == "CcyNtry": | ||||
|                             process_ccyntry(currency_list, ccyntrynode) | ||||
|     f.close() | ||||
|     return currency_list | ||||
|  | ||||
| def get_cusip_list(filename): | ||||
|     cusip_list = SecurityList("") | ||||
|     with open(filename) as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter=',') | ||||
|         for row in csvreader: | ||||
|             cusip = row[0] | ||||
|             name = row[1] | ||||
|             description = ",".join(row[2:]) | ||||
|             cusip_list.add(Security(name, description, cusip, "Stock", 5)) | ||||
|     return cusip_list | ||||
|  | ||||
| def main(): | ||||
|     currency_list = get_currency_list() | ||||
|     cusip_list = get_cusip_list('cusip_list.csv') | ||||
|  | ||||
|     print("package handlers\n") | ||||
|     print("var SecurityTemplates = []Security{") | ||||
|     print(currency_list.unicode()) | ||||
|     print(cusip_list.unicode()) | ||||
|     print("}") | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
							
								
								
									
										442
									
								
								internal/handlers/securities.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										442
									
								
								internal/handlers/securities.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,442 @@ | ||||
| package handlers | ||||
|  | ||||
| //go:generate make | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Currency int64 = 1 | ||||
| 	Stock          = 2 | ||||
| ) | ||||
|  | ||||
| func GetSecurityType(typestring string) int64 { | ||||
| 	if strings.EqualFold(typestring, "currency") { | ||||
| 		return Currency | ||||
| 	} else if strings.EqualFold(typestring, "stock") { | ||||
| 		return Stock | ||||
| 	} else { | ||||
| 		return 0 | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Security struct { | ||||
| 	SecurityId  int64 | ||||
| 	UserId      int64 | ||||
| 	Name        string | ||||
| 	Description string | ||||
| 	Symbol      string | ||||
| 	// Number of decimal digits (to the right of the decimal point) this | ||||
| 	// security is precise to | ||||
| 	Precision int | ||||
| 	Type      int64 | ||||
| 	// AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency | ||||
| 	AlternateId string | ||||
| } | ||||
|  | ||||
| type SecurityList struct { | ||||
| 	Securities *[]*Security `json:"securities"` | ||||
| } | ||||
|  | ||||
| func (s *Security) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(s) | ||||
| } | ||||
|  | ||||
| func (s *Security) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(s) | ||||
| } | ||||
|  | ||||
| func (sl *SecurityList) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(sl) | ||||
| } | ||||
|  | ||||
| func SearchSecurityTemplates(search string, _type int64, limit int64) []*Security { | ||||
| 	upperSearch := strings.ToUpper(search) | ||||
| 	var results []*Security | ||||
| 	for i, security := range SecurityTemplates { | ||||
| 		if strings.Contains(strings.ToUpper(security.Name), upperSearch) || | ||||
| 			strings.Contains(strings.ToUpper(security.Description), upperSearch) || | ||||
| 			strings.Contains(strings.ToUpper(security.Symbol), upperSearch) { | ||||
| 			if _type == 0 || _type == security.Type { | ||||
| 				results = append(results, &SecurityTemplates[i]) | ||||
| 				if limit != -1 && int64(len(results)) >= limit { | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return results | ||||
| } | ||||
|  | ||||
| func FindSecurityTemplate(name string, _type int64) *Security { | ||||
| 	for _, security := range SecurityTemplates { | ||||
| 		if name == security.Name && _type == security.Type { | ||||
| 			return &security | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func FindCurrencyTemplate(iso4217 int64) *Security { | ||||
| 	iso4217string := strconv.FormatInt(iso4217, 10) | ||||
| 	for _, security := range SecurityTemplates { | ||||
| 		if security.Type == Currency && security.AlternateId == iso4217string { | ||||
| 			return &security | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) { | ||||
| 	var s Security | ||||
|  | ||||
| 	err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) { | ||||
| 	var s Security | ||||
|  | ||||
| 	err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func GetSecurities(db *DB, userid int64) (*[]*Security, error) { | ||||
| 	var securities []*Security | ||||
|  | ||||
| 	_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &securities, nil | ||||
| } | ||||
|  | ||||
| func InsertSecurity(db *DB, s *Security) error { | ||||
| 	err := db.Insert(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error { | ||||
| 	err := transaction.Insert(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateSecurity(db *DB, s *Security) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	user, err := GetUserTx(transaction, s.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Cannot change security which is user's default currency to be non-currency") | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Update(s) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Updated more than one security") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteSecurity(db *DB, s *Security) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// First, ensure no accounts are using this security | ||||
| 	accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) | ||||
|  | ||||
| 	if accounts != 0 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("One or more accounts still use this security") | ||||
| 	} | ||||
|  | ||||
| 	user, err := GetUserTx(transaction, s.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if user.DefaultCurrency == s.SecurityId { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Cannot delete security which is user's default currency") | ||||
| 	} | ||||
|  | ||||
| 	// Remove all prices involving this security (either of this security, or | ||||
| 	// using it as a currency) | ||||
| 	_, err = transaction.Exec("DELETE * FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(s) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Deleted more than one security") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) { | ||||
| 	security.UserId = userid | ||||
| 	if len(security.AlternateId) == 0 { | ||||
| 		// Always create a new local security if we can't match on the AlternateId | ||||
| 		err := InsertSecurityTx(transaction, security) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return security, nil | ||||
| 	} | ||||
|  | ||||
| 	var securities []*Security | ||||
|  | ||||
| 	_, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// First try to find a case insensitive match on the name or symbol | ||||
| 	upperName := strings.ToUpper(security.Name) | ||||
| 	upperSymbol := strings.ToUpper(security.Symbol) | ||||
| 	for _, s := range securities { | ||||
| 		if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) || | ||||
| 			(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) { | ||||
| 			return s, nil | ||||
| 		} | ||||
| 	} | ||||
| 	//		if strings.Contains(strings.ToUpper(security.Name), upperSearch) || | ||||
|  | ||||
| 	// Try to find a partial string match on the name or symbol | ||||
| 	for _, s := range securities { | ||||
| 		sUpperName := strings.ToUpper(s.Name) | ||||
| 		sUpperSymbol := strings.ToUpper(s.Symbol) | ||||
| 		if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) || | ||||
| 			(len(upperSymbol) > 0 && len(s.Symbol) > 0 && (strings.Contains(upperSymbol, sUpperSymbol) || strings.Contains(sUpperSymbol, upperSymbol))) { | ||||
| 			return s, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Give up and return the first security in the list | ||||
| 	if len(securities) > 0 { | ||||
| 		return securities[0], nil | ||||
| 	} | ||||
|  | ||||
| 	// If there wasn't even one security in the list, make a new one | ||||
| 	err = InsertSecurityTx(transaction, security) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return security, nil | ||||
| } | ||||
|  | ||||
| func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		security_json := r.PostFormValue("security") | ||||
| 		if security_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var security Security | ||||
| 		err := security.Read(security_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		security.SecurityId = -1 | ||||
| 		security.UserId = user.UserId | ||||
|  | ||||
| 		err = InsertSecurity(db, &security) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = security.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		var securityid int64 | ||||
| 		n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid) | ||||
|  | ||||
| 		if err != nil || n != 1 { | ||||
| 			//Return all securities | ||||
| 			var sl SecurityList | ||||
|  | ||||
| 			securities, err := GetSecurities(db, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			sl.Securities = securities | ||||
| 			err = (&sl).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			security, err := GetSecurity(db, securityid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = security.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		securityid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		if r.Method == "PUT" { | ||||
| 			security_json := r.PostFormValue("security") | ||||
| 			if security_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var security Security | ||||
| 			err := security.Read(security_json) | ||||
| 			if err != nil || security.SecurityId != securityid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			security.UserId = user.UserId | ||||
|  | ||||
| 			err = UpdateSecurity(db, &security) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = security.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			security, err := GetSecurity(db, securityid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteSecurity(db, security) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { | ||||
| 	if r.Method == "GET" { | ||||
| 		var sl SecurityList | ||||
|  | ||||
| 		query, _ := url.ParseQuery(r.URL.RawQuery) | ||||
|  | ||||
| 		var limit int64 = -1 | ||||
| 		search := query.Get("search") | ||||
| 		_type := GetSecurityType(query.Get("type")) | ||||
|  | ||||
| 		limitstring := query.Get("limit") | ||||
| 		if limitstring != "" { | ||||
| 			limitint, err := strconv.ParseInt(limitstring, 10, 0) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			limit = limitint | ||||
| 		} | ||||
|  | ||||
| 		securities := SearchSecurityTemplates(search, _type, limit) | ||||
|  | ||||
| 		sl.Securities = &securities | ||||
| 		err := (&sl).Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		WriteError(w, 3 /*Invalid Request*/) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										188
									
								
								internal/handlers/securities_lua.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										188
									
								
								internal/handlers/securities_lua.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,188 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| const luaSecurityTypeName = "security" | ||||
|  | ||||
| func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { | ||||
| 	var security_map map[int64]*Security | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find DB in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) | ||||
| 	if !ok { | ||||
| 		user, ok := ctx.Value(userContextKey).(*User) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		securities, err := GetSecurities(db, user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		security_map = make(map[int64]*Security) | ||||
| 		for i := range *securities { | ||||
| 			security_map[(*securities)[i].SecurityId] = (*securities)[i] | ||||
| 		} | ||||
|  | ||||
| 		ctx = context.WithValue(ctx, securitiesContextKey, security_map) | ||||
| 		L.SetContext(ctx) | ||||
| 	} | ||||
|  | ||||
| 	return security_map, nil | ||||
| } | ||||
|  | ||||
| func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) { | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	user, ok := ctx.Value(userContextKey).(*User) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	if security, ok := security_map[user.DefaultCurrency]; ok { | ||||
| 		return security, nil | ||||
| 	} else { | ||||
| 		return nil, errors.New("DefaultCurrency not in lua security_map") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func luaGetDefaultCurrency(L *lua.LState) int { | ||||
| 	defcurrency, err := luaContextGetDefaultCurrency(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetDefaultCurrency couldn't fetch default currency") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(SecurityToLua(L, defcurrency)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaGetSecurities(L *lua.LState) int { | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetSecurities couldn't fetch securities") | ||||
| 	} | ||||
|  | ||||
| 	table := L.NewTable() | ||||
|  | ||||
| 	for securityid := range security_map { | ||||
| 		table.RawSetInt(int(securityid), SecurityToLua(L, security_map[securityid])) | ||||
| 	} | ||||
|  | ||||
| 	L.Push(table) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaRegisterSecurities(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaSecurityTypeName) | ||||
| 	L.SetGlobal("security", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaSecurity__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaSecurity__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaSecurity__eq)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| 	getSecuritiesFn := L.NewFunction(luaGetSecurities) | ||||
| 	L.SetField(mt, "get_all", getSecuritiesFn) | ||||
| 	getDefaultCurrencyFn := L.NewFunction(luaGetDefaultCurrency) | ||||
| 	L.SetField(mt, "get_default", getDefaultCurrencyFn) | ||||
|  | ||||
| 	// also register the get_securities and get_default functions as globals in | ||||
| 	// their own right | ||||
| 	L.SetGlobal("get_securities", getSecuritiesFn) | ||||
| 	L.SetGlobal("get_default_currency", getDefaultCurrencyFn) | ||||
| } | ||||
|  | ||||
| func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = security | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Security and returns this *Security. | ||||
| func luaCheckSecurity(L *lua.LState, n int) *Security { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if security, ok := ud.Value.(*Security); ok { | ||||
| 		return security | ||||
| 	} | ||||
| 	L.ArgError(n, "security expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaSecurity__index(L *lua.LState) int { | ||||
| 	a := luaCheckSecurity(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "SecurityId", "securityid": | ||||
| 		L.Push(lua.LNumber(float64(a.SecurityId))) | ||||
| 	case "Name", "name": | ||||
| 		L.Push(lua.LString(a.Name)) | ||||
| 	case "Description", "description": | ||||
| 		L.Push(lua.LString(a.Description)) | ||||
| 	case "Symbol", "symbol": | ||||
| 		L.Push(lua.LString(a.Symbol)) | ||||
| 	case "Precision", "precision": | ||||
| 		L.Push(lua.LNumber(float64(a.Precision))) | ||||
| 	case "Type", "type": | ||||
| 		L.Push(lua.LNumber(float64(a.Type))) | ||||
| 	case "ClosestPrice", "closestprice": | ||||
| 		L.Push(L.NewFunction(luaClosestPrice)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected security attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaClosestPrice(L *lua.LState) int { | ||||
| 	s := luaCheckSecurity(L, 1) | ||||
| 	c := luaCheckSecurity(L, 2) | ||||
| 	date := luaCheckTime(L, 3) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	db, ok := ctx.Value(dbContextKey).(*DB) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find DB in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	p, err := GetClosestPrice(db, s, c, date) | ||||
| 	if err != nil { | ||||
| 		L.Push(lua.LNil) | ||||
| 	} else { | ||||
| 		L.Push(PriceToLua(L, p)) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSecurity__tostring(L *lua.LState) int { | ||||
| 	s := luaCheckSecurity(L, 1) | ||||
|  | ||||
| 	L.Push(lua.LString(s.Name + " - " + s.Description + " (" + s.Symbol + ")")) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSecurity__eq(L *lua.LState) int { | ||||
| 	a := luaCheckSecurity(L, 1) | ||||
| 	b := luaCheckSecurity(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.SecurityId == b.SecurityId)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										139
									
								
								internal/handlers/sessions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								internal/handlers/sessions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type Session struct { | ||||
| 	SessionId     int64 | ||||
| 	SessionSecret string `json:"-"` | ||||
| 	UserId        int64 | ||||
| } | ||||
|  | ||||
| func (s *Session) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(s) | ||||
| } | ||||
|  | ||||
| func GetSession(db *DB, r *http.Request) (*Session, error) { | ||||
| 	var s Session | ||||
|  | ||||
| 	cookie, err := r.Cookie("moneygo-session") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("moneygo-session cookie not set") | ||||
| 	} | ||||
| 	s.SessionSecret = cookie.Value | ||||
|  | ||||
| 	err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func DeleteSessionIfExists(db *DB, r *http.Request) { | ||||
| 	// TODO do this in one transaction | ||||
| 	session, err := GetSession(db, r) | ||||
| 	if err == nil { | ||||
| 		db.Delete(session) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func NewSessionCookie() (string, error) { | ||||
| 	bits := make([]byte, 128) | ||||
| 	if _, err := io.ReadFull(rand.Reader, bits); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return base64.StdEncoding.EncodeToString(bits), nil | ||||
| } | ||||
|  | ||||
| func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { | ||||
| 	s := Session{} | ||||
|  | ||||
| 	session_secret, err := NewSessionCookie() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	cookie := http.Cookie{ | ||||
| 		Name:     "moneygo-session", | ||||
| 		Value:    session_secret, | ||||
| 		Path:     "/", | ||||
| 		Domain:   r.URL.Host, | ||||
| 		Expires:  time.Now().AddDate(0, 1, 0), // a month from now | ||||
| 		Secure:   true, | ||||
| 		HttpOnly: true, | ||||
| 	} | ||||
| 	http.SetCookie(w, &cookie) | ||||
|  | ||||
| 	s.SessionSecret = session_secret | ||||
| 	s.UserId = userid | ||||
|  | ||||
| 	err = db.Insert(&s) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	if r.Method == "POST" || r.Method == "PUT" { | ||||
| 		user_json := r.PostFormValue("user") | ||||
| 		if user_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		user := User{} | ||||
| 		err := user.Read(user_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		dbuser, err := GetUserByUsername(db, user.Username) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 2 /*Unauthorized Access*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		user.HashPassword() | ||||
| 		if user.PasswordHash != dbuser.PasswordHash { | ||||
| 			WriteError(w, 2 /*Unauthorized Access*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		DeleteSessionIfExists(db, r) | ||||
|  | ||||
| 		session, err := NewSession(db, w, r, dbuser.UserId) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = session.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		s, err := GetSession(db, r) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 1 /*Not Signed In*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		s.Write(w) | ||||
| 	} else if r.Method == "DELETE" { | ||||
| 		DeleteSessionIfExists(db, r) | ||||
| 		WriteSuccess(w) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										945
									
								
								internal/handlers/transactions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										945
									
								
								internal/handlers/transactions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,945 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"log" | ||||
| 	"math/big" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // Split.Status | ||||
| const ( | ||||
| 	Imported   int64 = 1 | ||||
| 	Entered          = 2 | ||||
| 	Cleared          = 3 | ||||
| 	Reconciled       = 4 | ||||
| 	Voided           = 5 | ||||
| ) | ||||
|  | ||||
| // Split.ImportSplitType | ||||
| const ( | ||||
| 	Default         int64 = 0 | ||||
| 	ImportAccount         = 1 // This split belongs to the main account being imported | ||||
| 	SubAccount            = 2 // This split belongs to a sub-account of that being imported | ||||
| 	ExternalAccount       = 3 | ||||
| 	TradingAccount        = 4 | ||||
| 	Commission            = 5 | ||||
| 	Taxes                 = 6 | ||||
| 	Fees                  = 7 | ||||
| 	Load                  = 8 | ||||
| 	IncomeAccount         = 9 | ||||
| 	ExpenseAccount        = 10 | ||||
| ) | ||||
|  | ||||
| type Split struct { | ||||
| 	SplitId         int64 | ||||
| 	TransactionId   int64 | ||||
| 	Status          int64 | ||||
| 	ImportSplitType int64 | ||||
|  | ||||
| 	// One of AccountId and SecurityId must be -1 | ||||
| 	// In normal splits, AccountId will be valid and SecurityId will be -1. The | ||||
| 	// only case where this is reversed is for transactions that have been | ||||
| 	// imported and not yet associated with an account. | ||||
| 	AccountId  int64 | ||||
| 	SecurityId int64 | ||||
|  | ||||
| 	RemoteId string // unique ID from server, for detecting duplicates | ||||
| 	Number   string // Check or reference number | ||||
| 	Memo     string | ||||
| 	Amount   string // String representation of decimal, suitable for passing to big.Rat.SetString() | ||||
| } | ||||
|  | ||||
| func GetBigAmount(amt string) (*big.Rat, error) { | ||||
| 	var r big.Rat | ||||
| 	_, success := r.SetString(amt) | ||||
| 	if !success { | ||||
| 		return nil, errors.New("Couldn't convert string amount to big.Rat via SetString()") | ||||
| 	} | ||||
| 	return &r, nil | ||||
| } | ||||
|  | ||||
| func (s *Split) GetAmount() (*big.Rat, error) { | ||||
| 	return GetBigAmount(s.Amount) | ||||
| } | ||||
|  | ||||
| func (s *Split) Valid() bool { | ||||
| 	if (s.AccountId == -1) == (s.SecurityId == -1) { | ||||
| 		return false | ||||
| 	} | ||||
| 	_, err := s.GetAmount() | ||||
| 	return err == nil | ||||
| } | ||||
|  | ||||
| func (s *Split) AlreadyImportedTx(transaction *gorp.Transaction) (bool, error) { | ||||
| 	count, err := transaction.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) | ||||
| 	return count == 1, err | ||||
| } | ||||
|  | ||||
| type Transaction struct { | ||||
| 	TransactionId int64 | ||||
| 	UserId        int64 | ||||
| 	Description   string | ||||
| 	Date          time.Time | ||||
| 	Splits        []*Split `db:"-"` | ||||
| } | ||||
|  | ||||
| type TransactionList struct { | ||||
| 	Transactions *[]Transaction `json:"transactions"` | ||||
| } | ||||
|  | ||||
| type AccountTransactionsList struct { | ||||
| 	Account           *Account | ||||
| 	Transactions      *[]Transaction | ||||
| 	TotalTransactions int64 | ||||
| 	BeginningBalance  string | ||||
| 	EndingBalance     string | ||||
| } | ||||
|  | ||||
| func (t *Transaction) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(t) | ||||
| } | ||||
|  | ||||
| func (t *Transaction) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(t) | ||||
| } | ||||
|  | ||||
| func (tl *TransactionList) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(tl) | ||||
| } | ||||
|  | ||||
| func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(atl) | ||||
| } | ||||
|  | ||||
| func (t *Transaction) Valid() bool { | ||||
| 	for i := range t.Splits { | ||||
| 		if !t.Splits[i].Valid() { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // Return a map of security ID's to big.Rat's containing the amount that | ||||
| // security is imbalanced by | ||||
| func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) { | ||||
| 	sums := make(map[int64]big.Rat) | ||||
|  | ||||
| 	if !t.Valid() { | ||||
| 		return nil, errors.New("Transaction invalid") | ||||
| 	} | ||||
|  | ||||
| 	for i := range t.Splits { | ||||
| 		securityid := t.Splits[i].SecurityId | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			var err error | ||||
| 			var account *Account | ||||
| 			account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			securityid = account.SecurityId | ||||
| 		} | ||||
| 		amount, _ := t.Splits[i].GetAmount() | ||||
| 		sum := sums[securityid] | ||||
| 		(&sum).Add(&sum, amount) | ||||
| 		sums[securityid] = sum | ||||
| 	} | ||||
| 	return sums, nil | ||||
| } | ||||
|  | ||||
| // Returns true if all securities contained in this transaction are balanced, | ||||
| // false otherwise | ||||
| func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { | ||||
| 	var zero big.Rat | ||||
|  | ||||
| 	sums, err := t.GetImbalancesTx(transaction) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	for _, security_sum := range sums { | ||||
| 		if security_sum.Cmp(&zero) != 0 { | ||||
| 			return false, nil | ||||
| 		} | ||||
| 	} | ||||
| 	return true, nil | ||||
| } | ||||
|  | ||||
| func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) { | ||||
| 	var t Transaction | ||||
|  | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &t, nil | ||||
| } | ||||
|  | ||||
| func GetTransactions(db *DB, userid int64) (*[]Transaction, error) { | ||||
| 	var transactions []Transaction | ||||
|  | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = transaction.Select(&transactions, "SELECT * from transactions where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	for i := range transactions { | ||||
| 		_, err := transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &transactions, nil | ||||
| } | ||||
|  | ||||
| func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error { | ||||
| 	for i := range accountids { | ||||
| 		account, err := GetAccountTx(transaction, accountids[i], user.UserId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		account.AccountVersion++ | ||||
| 		count, err := transaction.Update(account) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if count != 1 { | ||||
| 			return errors.New("Updated more than one account") | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type AccountMissingError struct{} | ||||
|  | ||||
| func (ame AccountMissingError) Error() string { | ||||
| 	return "Account missing" | ||||
| } | ||||
|  | ||||
| func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
| 	for i := range t.Splits { | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if existing != 1 { | ||||
| 				return AccountMissingError{} | ||||
| 			} | ||||
| 			a_map[t.Splits[i].AccountId] = true | ||||
| 		} else if t.Splits[i].SecurityId == -1 { | ||||
| 			return AccountMissingError{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	//increment versions for all accounts | ||||
| 	var a_ids []int64 | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	// ensure at least one of the splits is associated with an actual account | ||||
| 	if len(a_ids) < 1 { | ||||
| 		return AccountMissingError{} | ||||
| 	} | ||||
| 	err := incrementAccountVersions(transaction, user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	t.UserId = user.UserId | ||||
| 	err = transaction.Insert(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		t.Splits[i].SplitId = -1 | ||||
| 		err = transaction.Insert(t.Splits[i]) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InsertTransaction(db *DB, t *Transaction, user *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = InsertTransactionTx(transaction, t, user) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { | ||||
| 	var existing_splits []*Split | ||||
|  | ||||
| 	_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Map of any accounts with transaction splits being added | ||||
| 	a_map := make(map[int64]bool) | ||||
|  | ||||
| 	// Make a map with any existing splits for this transaction | ||||
| 	s_map := make(map[int64]bool) | ||||
| 	for i := range existing_splits { | ||||
| 		s_map[existing_splits[i].SplitId] = true | ||||
| 	} | ||||
|  | ||||
| 	// Insert splits, updating any pre-existing ones | ||||
| 	for i := range t.Splits { | ||||
| 		t.Splits[i].TransactionId = t.TransactionId | ||||
| 		_, ok := s_map[t.Splits[i].SplitId] | ||||
| 		if ok { | ||||
| 			count, err := transaction.Update(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if count != 1 { | ||||
| 				return errors.New("Updated more than one transaction split") | ||||
| 			} | ||||
| 			delete(s_map, t.Splits[i].SplitId) | ||||
| 		} else { | ||||
| 			t.Splits[i].SplitId = -1 | ||||
| 			err := transaction.Insert(t.Splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		if t.Splits[i].AccountId != -1 { | ||||
| 			a_map[t.Splits[i].AccountId] = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Delete any remaining pre-existing splits | ||||
| 	for i := range existing_splits { | ||||
| 		_, ok := s_map[existing_splits[i].SplitId] | ||||
| 		if existing_splits[i].AccountId != -1 { | ||||
| 			a_map[existing_splits[i].AccountId] = true | ||||
| 		} | ||||
| 		if ok { | ||||
| 			_, err := transaction.Delete(existing_splits[i]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Increment versions for all accounts with modified splits | ||||
| 	var a_ids []int64 | ||||
| 	for id := range a_map { | ||||
| 		a_ids = append(a_ids, id) | ||||
| 	} | ||||
| 	err = incrementAccountVersions(transaction, user, a_ids) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Update(t) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		return errors.New("Updated more than one transaction") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func DeleteTransaction(db *DB, t *Transaction, user *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var accountids []int64 | ||||
| 	_, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	_, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Delete(t) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Deleted more than one transaction") | ||||
| 	} | ||||
|  | ||||
| 	err = incrementAccountVersions(transaction, user, accountids) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	user, err := GetUserFromSession(db, r) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 1 /*Not Signed In*/) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r.Method == "POST" { | ||||
| 		transaction_json := r.PostFormValue("transaction") | ||||
| 		if transaction_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var transaction Transaction | ||||
| 		err := transaction.Read(transaction_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		transaction.TransactionId = -1 | ||||
| 		transaction.UserId = user.UserId | ||||
|  | ||||
| 		sqltx, err := db.Begin() | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		balanced, err := transaction.Balanced(sqltx) | ||||
| 		if err != nil { | ||||
| 			sqltx.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 		if !transaction.Valid() || !balanced { | ||||
| 			sqltx.Rollback() | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		for i := range transaction.Splits { | ||||
| 			transaction.Splits[i].SplitId = -1 | ||||
| 			_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		err = InsertTransactionTx(sqltx, &transaction, user) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(AccountMissingError); ok { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 			} else { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 			} | ||||
| 			sqltx.Rollback() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = sqltx.Commit() | ||||
| 		if err != nil { | ||||
| 			sqltx.Rollback() | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		err = transaction.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else if r.Method == "GET" { | ||||
| 		transactionid, err := GetURLID(r.URL.Path) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			//Return all Transactions | ||||
| 			var al TransactionList | ||||
| 			transactions, err := GetTransactions(db, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			al.Transactions = transactions | ||||
| 			err = (&al).Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			//Return Transaction with this Id | ||||
| 			transaction, err := GetTransaction(db, transactionid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			err = transaction.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		transactionid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		if r.Method == "PUT" { | ||||
| 			transaction_json := r.PostFormValue("transaction") | ||||
| 			if transaction_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var transaction Transaction | ||||
| 			err := transaction.Read(transaction_json) | ||||
| 			if err != nil || transaction.TransactionId != transactionid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
| 			transaction.UserId = user.UserId | ||||
|  | ||||
| 			sqltx, err := db.Begin() | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			balanced, err := transaction.Balanced(sqltx) | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 			if !transaction.Valid() || !balanced { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			for i := range transaction.Splits { | ||||
| 				_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) | ||||
| 				if err != nil { | ||||
| 					sqltx.Rollback() | ||||
| 					WriteError(w, 3 /*Invalid Request*/) | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateTransactionTx(sqltx, &transaction, user) | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = sqltx.Commit() | ||||
| 			if err != nil { | ||||
| 				sqltx.Rollback() | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = transaction.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			transactionid, err := GetURLID(r.URL.Path) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			transaction, err := GetTransaction(db, transactionid, user.UserId) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = DeleteTransaction(db, transaction, user) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { | ||||
| 	var pageDifference, tmp big.Rat | ||||
| 	for i := range transactions { | ||||
| 		_, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		// Sum up the amounts from the splits we're returning so we can return | ||||
| 		// an ending balance | ||||
| 		for j := range transactions[i].Splits { | ||||
| 			if transactions[i].Splits[j].AccountId == accountid { | ||||
| 				rat_amount, err := GetBigAmount(transactions[i].Splits[j].Amount) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
| 				tmp.Add(&pageDifference, rat_amount) | ||||
| 				pageDifference.Set(&tmp) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return &pageDifference, nil | ||||
| } | ||||
|  | ||||
| func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { | ||||
| 	var splits []Split | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" | ||||
| 	_, err = transaction.Select(&splits, sql, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| // Assumes accountid is valid and is owned by the current user | ||||
| func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) { | ||||
| 	var splits []Split | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" | ||||
| 	_, err = transaction.Select(&splits, sql, accountid, user.UserId, date) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { | ||||
| 	var splits []Split | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" | ||||
| 	_, err = transaction.Select(&splits, sql, accountid, user.UserId, begin, end) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range splits { | ||||
| 		rat_amount, err := GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { | ||||
| 	var transactions []Transaction | ||||
| 	var atl AccountTransactionsList | ||||
|  | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var sqlsort, balanceLimitOffset string | ||||
| 	var balanceLimitOffsetArg uint64 | ||||
| 	if sort == "date-asc" { | ||||
| 		sqlsort = " ORDER BY transactions.Date ASC" | ||||
| 		balanceLimitOffset = " LIMIT ?" | ||||
| 		balanceLimitOffsetArg = page * limit | ||||
| 	} else if sort == "date-desc" { | ||||
| 		numSplits, err := transaction.SelectInt("SELECT count(*) FROM splits") | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlsort = " ORDER BY transactions.Date DESC" | ||||
| 		balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) | ||||
| 		balanceLimitOffsetArg = (page + 1) * limit | ||||
| 	} | ||||
|  | ||||
| 	var sqloffset string | ||||
| 	if page > 0 { | ||||
| 		sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) | ||||
| 	} | ||||
|  | ||||
| 	account, err := GetAccountTx(transaction, accountid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Account = account | ||||
|  | ||||
| 	sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset | ||||
| 	_, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.Transactions = &transactions | ||||
|  | ||||
| 	pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	atl.TotalTransactions = count | ||||
|  | ||||
| 	security, err := GetSecurityTx(transaction, atl.Account.SecurityId, user.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if security == nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, errors.New("Security not found") | ||||
| 	} | ||||
|  | ||||
| 	// Sum all the splits for all transaction splits for this account that | ||||
| 	// occurred before the page we're returning | ||||
| 	var amounts []string | ||||
| 	sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")" | ||||
| 	_, err = transaction.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var tmp, balance big.Rat | ||||
| 	for _, amount := range amounts { | ||||
| 		rat_amount, err := GetBigAmount(amount) | ||||
| 		if err != nil { | ||||
| 			transaction.Rollback() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
| 	atl.BeginningBalance = balance.FloatString(security.Precision) | ||||
| 	atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &atl, nil | ||||
| } | ||||
|  | ||||
| // Return only those transactions which have at least one split pertaining to | ||||
| // an account | ||||
| func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, | ||||
| 	user *User, accountid int64) { | ||||
|  | ||||
| 	var page uint64 = 0 | ||||
| 	var limit uint64 = 50 | ||||
| 	var sort string = "date-desc" | ||||
|  | ||||
| 	query, _ := url.ParseQuery(r.URL.RawQuery) | ||||
|  | ||||
| 	pagestring := query.Get("page") | ||||
| 	if pagestring != "" { | ||||
| 		p, err := strconv.ParseUint(pagestring, 10, 0) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		page = p | ||||
| 	} | ||||
|  | ||||
| 	limitstring := query.Get("limit") | ||||
| 	if limitstring != "" { | ||||
| 		l, err := strconv.ParseUint(limitstring, 10, 0) | ||||
| 		if err != nil || l > 100 { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		limit = l | ||||
| 	} | ||||
|  | ||||
| 	sortstring := query.Get("sort") | ||||
| 	if sortstring != "" { | ||||
| 		if sortstring != "date-asc" && sortstring != "date-desc" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		sort = sortstring | ||||
| 	} | ||||
|  | ||||
| 	accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = accountTransactions.Write(w) | ||||
| 	if err != nil { | ||||
| 		WriteError(w, 999 /*Internal Error*/) | ||||
| 		log.Print(err) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										291
									
								
								internal/handlers/users.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										291
									
								
								internal/handlers/users.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,291 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gopkg.in/gorp.v1" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type User struct { | ||||
| 	UserId          int64 | ||||
| 	DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user | ||||
| 	Name            string | ||||
| 	Username        string | ||||
| 	Password        string `db:"-"` | ||||
| 	PasswordHash    string `json:"-"` | ||||
| 	Email           string | ||||
| } | ||||
|  | ||||
| const BogusPassword = "password" | ||||
|  | ||||
| type UserExistsError struct{} | ||||
|  | ||||
| func (ueu UserExistsError) Error() string { | ||||
| 	return "User exists" | ||||
| } | ||||
|  | ||||
| func (u *User) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(u) | ||||
| } | ||||
|  | ||||
| func (u *User) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(u) | ||||
| } | ||||
|  | ||||
| func (u *User) HashPassword() { | ||||
| 	password_hasher := sha256.New() | ||||
| 	io.WriteString(password_hasher, u.Password) | ||||
| 	u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil)) | ||||
| 	u.Password = "" | ||||
| } | ||||
|  | ||||
| func GetUser(db *DB, userid int64) (*User, error) { | ||||
| 	var u User | ||||
|  | ||||
| 	err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) { | ||||
| 	var u User | ||||
|  | ||||
| 	err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func GetUserByUsername(db *DB, username string) (*User, error) { | ||||
| 	var u User | ||||
|  | ||||
| 	err := db.SelectOne(&u, "SELECT * from users where Username=?", username) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func InsertUser(db *DB, u *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	security_template := FindCurrencyTemplate(u.DefaultCurrency) | ||||
| 	if security_template == nil { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Invalid ISO4217 Default Currency") | ||||
| 	} | ||||
|  | ||||
| 	existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	if existing > 0 { | ||||
| 		transaction.Rollback() | ||||
| 		return UserExistsError{} | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Insert(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Copy the security template and give it our new UserId | ||||
| 	var security Security | ||||
| 	security = *security_template | ||||
| 	security.UserId = u.UserId | ||||
|  | ||||
| 	err = InsertSecurityTx(transaction, &security) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Update the user's DefaultCurrency to our new SecurityId | ||||
| 	u.DefaultCurrency = security.SecurityId | ||||
| 	count, err := transaction.Update(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Would have updated more than one user") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetUserFromSession(db *DB, r *http.Request) (*User, error) { | ||||
| 	s, err := GetSession(db, r) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return GetUser(db, s.UserId) | ||||
| } | ||||
|  | ||||
| func UpdateUser(db *DB, u *User) error { | ||||
| 	transaction, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("UserId and DefaultCurrency don't match the fetched security") | ||||
| 	} else if security.Type != Currency { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("New DefaultCurrency security is not a currency") | ||||
| 	} | ||||
|  | ||||
| 	count, err := transaction.Update(u) | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} else if count != 1 { | ||||
| 		transaction.Rollback() | ||||
| 		return errors.New("Would have updated more than one user") | ||||
| 	} | ||||
|  | ||||
| 	err = transaction.Commit() | ||||
| 	if err != nil { | ||||
| 		transaction.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { | ||||
| 	if r.Method == "POST" { | ||||
| 		user_json := r.PostFormValue("user") | ||||
| 		if user_json == "" { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var user User | ||||
| 		err := user.Read(user_json) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
| 		user.UserId = -1 | ||||
| 		user.HashPassword() | ||||
|  | ||||
| 		err = InsertUser(db, &user) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(UserExistsError); ok { | ||||
| 				WriteError(w, 4 /*User Exists*/) | ||||
| 			} else { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		w.WriteHeader(201 /*Created*/) | ||||
| 		err = user.Write(w) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 999 /*Internal Error*/) | ||||
| 			log.Print(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		user, err := GetUserFromSession(db, r) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 1 /*Not Signed In*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		userid, err := GetURLID(r.URL.Path) | ||||
| 		if err != nil { | ||||
| 			WriteError(w, 3 /*Invalid Request*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if userid != user.UserId { | ||||
| 			WriteError(w, 2 /*Unauthorized Access*/) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if r.Method == "GET" { | ||||
| 			err = user.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else if r.Method == "PUT" { | ||||
| 			user_json := r.PostFormValue("user") | ||||
| 			if user_json == "" { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// Save old PWHash in case the new password is bogus | ||||
| 			old_pwhash := user.PasswordHash | ||||
|  | ||||
| 			err = user.Read(user_json) | ||||
| 			if err != nil || user.UserId != userid { | ||||
| 				WriteError(w, 3 /*Invalid Request*/) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// If the user didn't create a new password, keep their old one | ||||
| 			if user.Password != BogusPassword { | ||||
| 				user.HashPassword() | ||||
| 			} else { | ||||
| 				user.Password = "" | ||||
| 				user.PasswordHash = old_pwhash | ||||
| 			} | ||||
|  | ||||
| 			err = UpdateUser(db, user) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			err = user.Write(w) | ||||
| 			if err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
| 		} else if r.Method == "DELETE" { | ||||
| 			count, err := db.Delete(&user) | ||||
| 			if count != 1 || err != nil { | ||||
| 				WriteError(w, 999 /*Internal Error*/) | ||||
| 				log.Print(err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			WriteSuccess(w) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										23
									
								
								internal/handlers/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								internal/handlers/util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func GetURLID(url string) (int64, error) { | ||||
| 	pieces := strings.Split(strings.Trim(url, "/"), "/") | ||||
| 	return strconv.ParseInt(pieces[len(pieces)-1], 10, 0) | ||||
| } | ||||
|  | ||||
| func GetURLPieces(url string, format string, a ...interface{}) (int, error) { | ||||
| 	url = strings.Replace(url, "/", " ", -1) | ||||
| 	format = strings.Replace(format, "/", " ", -1) | ||||
| 	return fmt.Sscanf(url, format, a...) | ||||
| } | ||||
|  | ||||
| func WriteSuccess(w http.ResponseWriter) { | ||||
| 	fmt.Fprint(w, "{}") | ||||
| } | ||||
		Reference in New Issue
	
	Block a user