mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-10-30 09:33:25 -04:00 
			
		
		
		
	Split Lua reports into own package
This commit is contained in:
		| @@ -1,223 +0,0 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const luaAccountTypeName = "account" | ||||
|  | ||||
| func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { | ||||
| 	var account_map map[int64]*models.Account | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	account_map, ok = ctx.Value(accountsContextKey).(map[int64]*models.Account) | ||||
| 	if !ok { | ||||
| 		user, ok := ctx.Value(userContextKey).(*models.User) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		accounts, err := tx.GetAccounts(user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		account_map = make(map[int64]*models.Account) | ||||
| 		for i := range *accounts { | ||||
| 			account_map[(*accounts)[i].AccountId] = (*accounts)[i] | ||||
| 		} | ||||
|  | ||||
| 		ctx = context.WithValue(ctx, accountsContextKey, account_map) | ||||
| 		L.SetContext(ctx) | ||||
| 	} | ||||
|  | ||||
| 	return account_map, nil | ||||
| } | ||||
|  | ||||
| func luaGetAccounts(L *lua.LState) int { | ||||
| 	account_map, err := luaContextGetAccounts(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetAccounts couldn't fetch accounts") | ||||
| 	} | ||||
|  | ||||
| 	table := L.NewTable() | ||||
|  | ||||
| 	for accountid := range account_map { | ||||
| 		table.RawSetInt(int(accountid), AccountToLua(L, account_map[accountid])) | ||||
| 	} | ||||
|  | ||||
| 	L.Push(table) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaRegisterAccounts(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaAccountTypeName) | ||||
| 	L.SetGlobal("account", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaAccount__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
|  | ||||
| 	for _, accttype := range models.AccountTypes { | ||||
| 		L.SetField(mt, accttype.String(), lua.LNumber(float64(accttype))) | ||||
| 	} | ||||
|  | ||||
| 	getAccountsFn := L.NewFunction(luaGetAccounts) | ||||
| 	L.SetField(mt, "get_all", getAccountsFn) | ||||
| 	// also register the get_accounts function as a global in its own right | ||||
| 	L.SetGlobal("get_accounts", getAccountsFn) | ||||
| } | ||||
|  | ||||
| func AccountToLua(L *lua.LState, account *models.Account) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = account | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Account and returns this *Account. | ||||
| func luaCheckAccount(L *lua.LState, n int) *models.Account { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if account, ok := ud.Value.(*models.Account); ok { | ||||
| 		return account | ||||
| 	} | ||||
| 	L.ArgError(n, "account expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaAccount__index(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "AccountId", "accountid": | ||||
| 		L.Push(lua.LNumber(float64(a.AccountId))) | ||||
| 	case "Security", "security": | ||||
| 		security_map, err := luaContextGetSecurities(L) | ||||
| 		if err != nil { | ||||
| 			panic("account.security couldn't fetch securities") | ||||
| 		} | ||||
| 		if security, ok := security_map[a.SecurityId]; ok { | ||||
| 			L.Push(SecurityToLua(L, security)) | ||||
| 		} else { | ||||
| 			panic("SecurityId not in lua security_map") | ||||
| 		} | ||||
| 	case "SecurityId", "securityid": | ||||
| 		L.Push(lua.LNumber(float64(a.SecurityId))) | ||||
| 	case "Parent", "parent", "ParentAccount", "parentaccount": | ||||
| 		if a.ParentAccountId == -1 { | ||||
| 			L.Push(lua.LNil) | ||||
| 		} else { | ||||
| 			account_map, err := luaContextGetAccounts(L) | ||||
| 			if err != nil { | ||||
| 				panic("account.parent couldn't fetch accounts") | ||||
| 			} | ||||
| 			if parent, ok := account_map[a.ParentAccountId]; ok { | ||||
| 				L.Push(AccountToLua(L, parent)) | ||||
| 			} else { | ||||
| 				panic("ParentAccountId not in lua account_map") | ||||
| 			} | ||||
| 		} | ||||
| 	case "Name", "name": | ||||
| 		L.Push(lua.LString(a.Name)) | ||||
| 	case "Type", "type": | ||||
| 		L.Push(lua.LNumber(float64(a.Type))) | ||||
| 	case "TypeName", "Typename": | ||||
| 		L.Push(lua.LString(a.Type.String())) | ||||
| 	case "typename": | ||||
| 		L.Push(lua.LString(strings.ToLower(a.Type.String()))) | ||||
| 	case "Balance", "balance": | ||||
| 		L.Push(L.NewFunction(luaAccountBalance)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected account attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaAccountBalance(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
| 	user, ok := ctx.Value(userContextKey).(*models.User) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find User in lua's Context") | ||||
| 	} | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		panic("account.security couldn't fetch securities") | ||||
| 	} | ||||
| 	security, ok := security_map[a.SecurityId] | ||||
| 	if !ok { | ||||
| 		panic("SecurityId not in lua security_map") | ||||
| 	} | ||||
| 	date := luaWeakCheckTime(L, 2) | ||||
| 	var splits *[]*models.Split | ||||
| 	if date != nil { | ||||
| 		end := luaWeakCheckTime(L, 3) | ||||
| 		if end != nil { | ||||
| 			splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end) | ||||
| 		} else { | ||||
| 			splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date) | ||||
| 		} | ||||
| 	} else { | ||||
| 		splits, err = tx.GetAccountSplits(user, a.AccountId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		panic("Failed to fetch splits for account:" + err.Error()) | ||||
| 	} | ||||
| 	rat, err := BalanceFromSplits(splits) | ||||
| 	if err != nil { | ||||
| 		panic("Failed to calculate balance for account:" + err.Error()) | ||||
| 	} | ||||
| 	b := &Balance{ | ||||
| 		Amount:   rat, | ||||
| 		Security: security, | ||||
| 	} | ||||
| 	L.Push(BalanceToLua(L, b)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaAccount__tostring(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
|  | ||||
| 	account_map, err := luaContextGetAccounts(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetAccounts couldn't fetch accounts") | ||||
| 	} | ||||
|  | ||||
| 	full_name := a.Name | ||||
| 	for a.ParentAccountId != -1 { | ||||
| 		a = account_map[a.ParentAccountId] | ||||
| 		full_name = a.Name + "/" + full_name | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LString(full_name)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaAccount__eq(L *lua.LState) int { | ||||
| 	a := luaCheckAccount(L, 1) | ||||
| 	b := luaCheckAccount(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.AccountId == b.AccountId)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
| @@ -1,225 +0,0 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"math/big" | ||||
| ) | ||||
|  | ||||
| type Balance struct { | ||||
| 	Security *models.Security | ||||
| 	Amount   *big.Rat | ||||
| } | ||||
|  | ||||
| const luaBalanceTypeName = "balance" | ||||
|  | ||||
| func luaRegisterBalances(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaBalanceTypeName) | ||||
| 	L.SetGlobal("balance", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaBalance__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaBalance__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaBalance__eq)) | ||||
| 	L.SetField(mt, "__lt", L.NewFunction(luaBalance__lt)) | ||||
| 	L.SetField(mt, "__le", L.NewFunction(luaBalance__le)) | ||||
| 	L.SetField(mt, "__add", L.NewFunction(luaBalance__add)) | ||||
| 	L.SetField(mt, "__sub", L.NewFunction(luaBalance__sub)) | ||||
| 	L.SetField(mt, "__mul", L.NewFunction(luaBalance__mul)) | ||||
| 	L.SetField(mt, "__div", L.NewFunction(luaBalance__div)) | ||||
| 	L.SetField(mt, "__unm", L.NewFunction(luaBalance__unm)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| func BalanceToLua(L *lua.LState, balance *Balance) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = balance | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaBalanceTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Balance and returns this *Balance. | ||||
| func luaCheckBalance(L *lua.LState, n int) *Balance { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if balance, ok := ud.Value.(*Balance); ok { | ||||
| 		return balance | ||||
| 	} | ||||
| 	L.ArgError(n, "balance expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaWeakCheckBalance(L *lua.LState, n int) *Balance { | ||||
| 	v := L.Get(n) | ||||
| 	if ud, ok := v.(*lua.LUserData); ok { | ||||
| 		if balance, ok := ud.Value.(*Balance); ok { | ||||
| 			return balance | ||||
| 		} | ||||
| 		L.ArgError(n, "balance expected") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaGetBalanceOperands(L *lua.LState, m int, n int) (*Balance, *Balance) { | ||||
| 	bm := luaWeakCheckBalance(L, m) | ||||
| 	bn := luaWeakCheckBalance(L, n) | ||||
|  | ||||
| 	if bm != nil && bn != nil { | ||||
| 		return bm, bn | ||||
| 	} else if bm != nil { | ||||
| 		nn := L.CheckNumber(n) | ||||
| 		var balance Balance | ||||
| 		var rat big.Rat | ||||
| 		balance.Security = bm.Security | ||||
| 		balance.Amount = rat.SetFloat64(float64(nn)) | ||||
| 		if balance.Amount == nil { | ||||
| 			L.ArgError(n, "non-finite float invalid for operand to balance arithemetic") | ||||
| 			return nil, nil | ||||
| 		} | ||||
| 		return bm, &balance | ||||
| 	} else if bn != nil { | ||||
| 		nm := L.CheckNumber(m) | ||||
| 		var balance Balance | ||||
| 		var rat big.Rat | ||||
| 		balance.Security = bn.Security | ||||
| 		balance.Amount = rat.SetFloat64(float64(nm)) | ||||
| 		if balance.Amount == nil { | ||||
| 			L.ArgError(m, "non-finite float invalid for operand to balance arithemetic") | ||||
| 			return nil, nil | ||||
| 		} | ||||
| 		return &balance, bn | ||||
| 	} | ||||
| 	L.ArgError(m, "balance expected") | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func luaBalance__index(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Security", "security": | ||||
| 		L.Push(SecurityToLua(L, a.Security)) | ||||
| 	case "Amount", "amount": | ||||
| 		float, _ := a.Amount.Float64() | ||||
| 		L.Push(lua.LNumber(float)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected balance attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__tostring(L *lua.LState) int { | ||||
| 	b := luaCheckBalance(L, 1) | ||||
|  | ||||
| 	L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__eq(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	b := luaCheckBalance(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__lt(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	b := luaCheckBalance(L, 2) | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't compare balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Amount.Cmp(b.Amount) < 0)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__le(L *lua.LState) int { | ||||
| 	a := luaCheckBalance(L, 1) | ||||
| 	b := luaCheckBalance(L, 2) | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't compare balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Amount.Cmp(b.Amount) <= 0)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__add(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't add balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Add(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__sub(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't subtract balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Sub(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__mul(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't multiply balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Mul(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__div(L *lua.LState) int { | ||||
| 	a, b := luaGetBalanceOperands(L, 1, 2) | ||||
|  | ||||
| 	if a.Security.SecurityId != b.Security.SecurityId { | ||||
| 		L.ArgError(2, "Can't divide balances with different securities") | ||||
| 	} | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = a.Security | ||||
| 	balance.Amount = rat.Quo(a.Amount, b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaBalance__unm(L *lua.LState) int { | ||||
| 	b := luaCheckBalance(L, 1) | ||||
|  | ||||
| 	var balance Balance | ||||
| 	var rat big.Rat | ||||
| 	balance.Security = b.Security | ||||
| 	balance.Amount = rat.Neg(b.Amount) | ||||
| 	L.Push(BalanceToLua(L, &balance)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
| @@ -1,170 +0,0 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const luaDateTypeName = "date" | ||||
| const timeFormat = "2006-01-02" | ||||
|  | ||||
| func luaRegisterDates(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaDateTypeName) | ||||
| 	L.SetGlobal("date", mt) | ||||
| 	L.SetField(mt, "new", L.NewFunction(luaDateNew)) | ||||
| 	L.SetField(mt, "now", L.NewFunction(luaDateNow)) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaDate__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaDate__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaDate__eq)) | ||||
| 	L.SetField(mt, "__lt", L.NewFunction(luaDate__lt)) | ||||
| 	L.SetField(mt, "__le", L.NewFunction(luaDate__le)) | ||||
| 	L.SetField(mt, "__add", L.NewFunction(luaDate__add)) | ||||
| 	L.SetField(mt, "__sub", L.NewFunction(luaDate__sub)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| func TimeToLua(L *lua.LState, date *time.Time) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = date | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaDateTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Time and returns this *Time. | ||||
| func luaCheckTime(L *lua.LState, n int) *time.Time { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if date, ok := ud.Value.(*time.Time); ok { | ||||
| 		return date | ||||
| 	} | ||||
| 	L.ArgError(n, "date expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaWeakCheckTime(L *lua.LState, n int) *time.Time { | ||||
| 	v := L.Get(n) | ||||
| 	if ud, ok := v.(*lua.LUserData); ok { | ||||
| 		if date, ok := ud.Value.(*time.Time); ok { | ||||
| 			return date | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaWeakCheckTableFieldInt(L *lua.LState, T *lua.LTable, n int, name string, def int) int { | ||||
| 	lv := T.RawGetString(name) | ||||
| 	if lv == lua.LNil { | ||||
| 		return def | ||||
| 	} | ||||
| 	if i, ok := lv.(lua.LNumber); ok { | ||||
| 		return int(i) | ||||
| 	} | ||||
| 	L.ArgError(n, "table field '"+name+"' expected to be int") | ||||
| 	return def | ||||
| } | ||||
|  | ||||
| func luaDateNew(L *lua.LState) int { | ||||
| 	// TODO make this track the user's timezone | ||||
| 	v := L.Get(1) | ||||
| 	if s, ok := v.(lua.LString); ok { | ||||
| 		date, err := time.ParseInLocation(timeFormat, s.String(), time.Local) | ||||
| 		if err != nil { | ||||
| 			L.ArgError(1, "error parsing date string: "+err.Error()) | ||||
| 			return 0 | ||||
| 		} | ||||
| 		L.Push(TimeToLua(L, &date)) | ||||
| 		return 1 | ||||
| 	} | ||||
| 	var year, month, day int | ||||
| 	if t, ok := v.(*lua.LTable); ok { | ||||
| 		year = luaWeakCheckTableFieldInt(L, t, 1, "year", 0) | ||||
| 		month = luaWeakCheckTableFieldInt(L, t, 1, "month", 1) | ||||
| 		day = luaWeakCheckTableFieldInt(L, t, 1, "day", 1) | ||||
| 	} else { | ||||
| 		year = L.CheckInt(1) | ||||
| 		month = L.CheckInt(2) | ||||
| 		day = L.CheckInt(3) | ||||
| 	} | ||||
| 	date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDateNow(L *lua.LState) int { | ||||
| 	now := time.Now() | ||||
| 	date := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__index(L *lua.LState) int { | ||||
| 	d := luaCheckTime(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Year", "year": | ||||
| 		L.Push(lua.LNumber(d.Year())) | ||||
| 	case "Month", "month": | ||||
| 		L.Push(lua.LNumber(float64(d.Month()))) | ||||
| 	case "Day", "day": | ||||
| 		L.Push(lua.LNumber(float64(d.Day()))) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected date attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__tostring(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
|  | ||||
| 	L.Push(lua.LString(a.Format(timeFormat))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__eq(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Equal(*b))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__lt(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Before(*b))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__le(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.Equal(*b) || a.Before(*b))) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__add(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	date := a.AddDate(b.Year(), int(b.Month()), b.Day()) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaDate__sub(L *lua.LState) int { | ||||
| 	a := luaCheckTime(L, 1) | ||||
| 	b := luaCheckTime(L, 2) | ||||
|  | ||||
| 	date := a.AddDate(-b.Year(), -int(b.Month()), -b.Day()) | ||||
| 	L.Push(TimeToLua(L, &date)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
| @@ -5,7 +5,6 @@ import ( | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error { | ||||
| @@ -33,27 +32,6 @@ func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Return the price for security in currency closest to date | ||||
| func GetClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { | ||||
| 	earliest, _ := tx.GetEarliestPrice(security, currency, date) | ||||
| 	latest, err := tx.GetLatestPrice(security, currency, date) | ||||
|  | ||||
| 	// Return early if either earliest or latest are invalid | ||||
| 	if earliest == nil { | ||||
| 		return latest, err | ||||
| 	} else if err != nil { | ||||
| 		return earliest, nil | ||||
| 	} | ||||
|  | ||||
| 	howlate := earliest.Date.Sub(*date) | ||||
| 	howearly := date.Sub(latest.Date) | ||||
| 	if howearly < howlate { | ||||
| 		return latest, nil | ||||
| 	} else { | ||||
| 		return earliest, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { | ||||
| 	security, err := context.Tx.GetSecurity(securityid, user.UserId) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -1,92 +0,0 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| const luaPriceTypeName = "price" | ||||
|  | ||||
| func luaRegisterPrices(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaPriceTypeName) | ||||
| 	L.SetGlobal("price", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaPrice__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaPrice__tostring)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| func PriceToLua(L *lua.LState, price *models.Price) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = price | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaPriceTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Price and returns this *Price. | ||||
| func luaCheckPrice(L *lua.LState, n int) *models.Price { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if price, ok := ud.Value.(*models.Price); ok { | ||||
| 		return price | ||||
| 	} | ||||
| 	L.ArgError(n, "price expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaPrice__index(L *lua.LState) int { | ||||
| 	p := luaCheckPrice(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "PriceId", "priceid": | ||||
| 		L.Push(lua.LNumber(float64(p.PriceId))) | ||||
| 	case "Security", "security": | ||||
| 		security_map, err := luaContextGetSecurities(L) | ||||
| 		if err != nil { | ||||
| 			panic("luaContextGetSecurities couldn't fetch securities") | ||||
| 		} | ||||
| 		s, ok := security_map[p.SecurityId] | ||||
| 		if !ok { | ||||
| 			panic("Price's security not found for user") | ||||
| 		} | ||||
| 		L.Push(SecurityToLua(L, s)) | ||||
| 	case "Currency", "currency": | ||||
| 		security_map, err := luaContextGetSecurities(L) | ||||
| 		if err != nil { | ||||
| 			panic("luaContextGetSecurities couldn't fetch securities") | ||||
| 		} | ||||
| 		c, ok := security_map[p.CurrencyId] | ||||
| 		if !ok { | ||||
| 			panic("Price's currency not found for user") | ||||
| 		} | ||||
| 		L.Push(SecurityToLua(L, c)) | ||||
| 	case "Value", "value": | ||||
| 		amt, err := models.GetBigAmount(p.Value) | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		float, _ := amt.Float64() | ||||
| 		L.Push(lua.LNumber(float)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected price attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaPrice__tostring(L *lua.LState) int { | ||||
| 	p := luaCheckPrice(L, 1) | ||||
|  | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaContextGetSecurities couldn't fetch securities") | ||||
| 	} | ||||
| 	s, ok1 := security_map[p.SecurityId] | ||||
| 	c, ok2 := security_map[p.CurrencyId] | ||||
| 	if !ok1 || !ok2 { | ||||
| 		panic("Price's currency or security not found for user") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")")) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
| @@ -1,104 +1,23 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/reports" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| //type and value to store user in lua's Context | ||||
| type key int | ||||
|  | ||||
| const ( | ||||
| 	userContextKey key = iota | ||||
| 	accountsContextKey | ||||
| 	securitiesContextKey | ||||
| 	balanceContextKey | ||||
| 	dbContextKey | ||||
| ) | ||||
|  | ||||
| const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for | ||||
|  | ||||
| func runReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { | ||||
| 	// Create a new LState without opening the default libs for security | ||||
| 	L := lua.NewState(lua.Options{SkipOpenLibs: true}) | ||||
| 	defer L.Close() | ||||
|  | ||||
| 	// Create a new context holding the current user with a timeout | ||||
| 	ctx := context.WithValue(context.Background(), userContextKey, user) | ||||
| 	ctx = context.WithValue(ctx, dbContextKey, tx) | ||||
| 	ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) | ||||
| 	defer cancel() | ||||
| 	L.SetContext(ctx) | ||||
|  | ||||
| 	for _, pair := range []struct { | ||||
| 		n string | ||||
| 		f lua.LGFunction | ||||
| 	}{ | ||||
| 		{lua.LoadLibName, lua.OpenPackage}, // Must be first | ||||
| 		{lua.BaseLibName, lua.OpenBase}, | ||||
| 		{lua.TabLibName, lua.OpenTable}, | ||||
| 		{lua.StringLibName, lua.OpenString}, | ||||
| 		{lua.MathLibName, lua.OpenMath}, | ||||
| 	} { | ||||
| 		if err := L.CallByParam(lua.P{ | ||||
| 			Fn:      L.NewFunction(pair.f), | ||||
| 			NRet:    0, | ||||
| 			Protect: true, | ||||
| 		}, lua.LString(pair.n)); err != nil { | ||||
| 			return nil, errors.New("Error initializing Lua packages") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	luaRegisterAccounts(L) | ||||
| 	luaRegisterSecurities(L) | ||||
| 	luaRegisterBalances(L) | ||||
| 	luaRegisterDates(L) | ||||
| 	luaRegisterTabulations(L) | ||||
| 	luaRegisterPrices(L) | ||||
|  | ||||
| 	err := L.DoString(report.Lua) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if err := L.CallByParam(lua.P{ | ||||
| 		Fn:      L.GetGlobal("generate"), | ||||
| 		NRet:    1, | ||||
| 		Protect: true, | ||||
| 	}); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	value := L.Get(-1) | ||||
| 	if ud, ok := value.(*lua.LUserData); ok { | ||||
| 		if tabulation, ok := ud.Value.(*models.Tabulation); ok { | ||||
| 			return tabulation, nil | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId) | ||||
| 		} | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("generate() for %s (Id: %d) didn't even return LUserData", report.Name, report.ReportId) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ReportTabulationHandler(tx store.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { | ||||
| 	report, err := tx.GetReport(reportid, user.UserId) | ||||
| 	if err != nil { | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
| 	tabulation, err := runReport(tx, user, report) | ||||
| 	tabulation, err := reports.RunReport(tx, user, report) | ||||
| 	if err != nil { | ||||
| 		// TODO handle different failure cases differently | ||||
| 		log.Print("runReport returned:", err) | ||||
| 		log.Print("reports.RunReport returned:", err) | ||||
| 		return NewError(3 /*Invalid Request*/) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -1,188 +0,0 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| const luaTabulationTypeName = "tabulation" | ||||
| const luaSeriesTypeName = "series" | ||||
|  | ||||
| func luaRegisterTabulations(L *lua.LState) { | ||||
| 	mtr := L.NewTypeMetatable(luaTabulationTypeName) | ||||
| 	L.SetGlobal("tabulation", mtr) | ||||
| 	L.SetField(mtr, "new", L.NewFunction(luaTabulationNew)) | ||||
| 	L.SetField(mtr, "__index", L.NewFunction(luaTabulation__index)) | ||||
| 	L.SetField(mtr, "__metatable", lua.LString("protected")) | ||||
|  | ||||
| 	mts := L.NewTypeMetatable(luaSeriesTypeName) | ||||
| 	L.SetGlobal("series", mts) | ||||
| 	L.SetField(mts, "__index", L.NewFunction(luaSeries__index)) | ||||
| 	L.SetField(mts, "__metatable", lua.LString("protected")) | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Tabulation and returns *Tabulation | ||||
| func luaCheckTabulation(L *lua.LState, n int) *models.Tabulation { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if tabulation, ok := ud.Value.(*models.Tabulation); ok { | ||||
| 		return tabulation | ||||
| 	} | ||||
| 	L.ArgError(n, "tabulation expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Series and returns *Series | ||||
| func luaCheckSeries(L *lua.LState, n int) *models.Series { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if series, ok := ud.Value.(*models.Series); ok { | ||||
| 		return series | ||||
| 	} | ||||
| 	L.ArgError(n, "series expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaTabulationNew(L *lua.LState) int { | ||||
| 	numvalues := L.CheckInt(1) | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = &models.Tabulation{ | ||||
| 		Labels: make([]string, numvalues), | ||||
| 		Series: make(map[string]*models.Series), | ||||
| 	} | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaTabulationTypeName)) | ||||
| 	L.Push(ud) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulation__index(L *lua.LState) int { | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Label", "label": | ||||
| 		L.Push(L.NewFunction(luaTabulationLabel)) | ||||
| 	case "Series", "series": | ||||
| 		L.Push(L.NewFunction(luaTabulationSeries)) | ||||
| 	case "Title", "title": | ||||
| 		L.Push(L.NewFunction(luaTabulationTitle)) | ||||
| 	case "Subtitle", "subtitle": | ||||
| 		L.Push(L.NewFunction(luaTabulationSubtitle)) | ||||
| 	case "Units", "units": | ||||
| 		L.Push(L.NewFunction(luaTabulationUnits)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected tabulation attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationLabel(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
| 	labelnumber := L.CheckInt(2) | ||||
| 	label := L.CheckString(3) | ||||
|  | ||||
| 	if labelnumber > cap(tabulation.Labels) || labelnumber < 1 { | ||||
| 		L.ArgError(2, "Label index must be between 1 and the number of data points, inclusive") | ||||
| 	} | ||||
| 	tabulation.Labels[labelnumber-1] = label | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func luaTabulationSeries(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
| 	name := L.CheckString(2) | ||||
| 	ud := L.NewUserData() | ||||
|  | ||||
| 	s, ok := tabulation.Series[name] | ||||
| 	if ok { | ||||
| 		ud.Value = s | ||||
| 	} else { | ||||
| 		tabulation.Series[name] = &models.Series{ | ||||
| 			Series: make(map[string]*models.Series), | ||||
| 			Values: make([]float64, cap(tabulation.Labels)), | ||||
| 		} | ||||
| 		ud.Value = tabulation.Series[name] | ||||
| 	} | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName)) | ||||
| 	L.Push(ud) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationTitle(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
|  | ||||
| 	if L.GetTop() == 2 { | ||||
| 		tabulation.Title = L.CheckString(2) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	L.Push(lua.LString(tabulation.Title)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationSubtitle(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
|  | ||||
| 	if L.GetTop() == 2 { | ||||
| 		tabulation.Subtitle = L.CheckString(2) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	L.Push(lua.LString(tabulation.Subtitle)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaTabulationUnits(L *lua.LState) int { | ||||
| 	tabulation := luaCheckTabulation(L, 1) | ||||
|  | ||||
| 	if L.GetTop() == 2 { | ||||
| 		tabulation.Units = L.CheckString(2) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	L.Push(lua.LString(tabulation.Units)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSeries__index(L *lua.LState) int { | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "Value", "value": | ||||
| 		L.Push(L.NewFunction(luaSeriesValue)) | ||||
| 	case "Series", "series": | ||||
| 		L.Push(L.NewFunction(luaSeriesSeries)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected series attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSeriesValue(L *lua.LState) int { | ||||
| 	series := luaCheckSeries(L, 1) | ||||
| 	valuenumber := L.CheckInt(2) | ||||
| 	value := float64(L.CheckNumber(3)) | ||||
|  | ||||
| 	if valuenumber > cap(series.Values) || valuenumber < 1 { | ||||
| 		L.ArgError(2, "value index must be between 1 and the number of data points, inclusive") | ||||
| 	} | ||||
| 	series.Values[valuenumber-1] = value | ||||
|  | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func luaSeriesSeries(L *lua.LState) int { | ||||
| 	parent := luaCheckSeries(L, 1) | ||||
| 	name := L.CheckString(2) | ||||
| 	ud := L.NewUserData() | ||||
|  | ||||
| 	s, ok := parent.Series[name] | ||||
| 	if ok { | ||||
| 		ud.Value = s | ||||
| 	} else { | ||||
| 		parent.Series[name] = &models.Series{ | ||||
| 			Series: make(map[string]*models.Series), | ||||
| 			Values: make([]float64, cap(parent.Values)), | ||||
| 		} | ||||
| 		ud.Value = parent.Series[name] | ||||
| 	} | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName)) | ||||
| 	L.Push(ud) | ||||
| 	return 1 | ||||
| } | ||||
| @@ -1,192 +0,0 @@ | ||||
| package handlers | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	"github.com/yuin/gopher-lua" | ||||
| ) | ||||
|  | ||||
| const luaSecurityTypeName = "security" | ||||
|  | ||||
| func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) { | ||||
| 	var security_map map[int64]*models.Security | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*models.Security) | ||||
| 	if !ok { | ||||
| 		user, ok := ctx.Value(userContextKey).(*models.User) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 		} | ||||
|  | ||||
| 		securities, err := tx.GetSecurities(user.UserId) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		security_map = make(map[int64]*models.Security) | ||||
| 		for i := range *securities { | ||||
| 			security_map[(*securities)[i].SecurityId] = (*securities)[i] | ||||
| 		} | ||||
|  | ||||
| 		ctx = context.WithValue(ctx, securitiesContextKey, security_map) | ||||
| 		L.SetContext(ctx) | ||||
| 	} | ||||
|  | ||||
| 	return security_map, nil | ||||
| } | ||||
|  | ||||
| func luaContextGetDefaultCurrency(L *lua.LState) (*models.Security, error) { | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	ctx := L.Context() | ||||
|  | ||||
| 	user, ok := ctx.Value(userContextKey).(*models.User) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Couldn't find User in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	if security, ok := security_map[user.DefaultCurrency]; ok { | ||||
| 		return security, nil | ||||
| 	} else { | ||||
| 		return nil, errors.New("DefaultCurrency not in lua security_map") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func luaGetDefaultCurrency(L *lua.LState) int { | ||||
| 	defcurrency, err := luaContextGetDefaultCurrency(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetDefaultCurrency couldn't fetch default currency") | ||||
| 	} | ||||
|  | ||||
| 	L.Push(SecurityToLua(L, defcurrency)) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaGetSecurities(L *lua.LState) int { | ||||
| 	security_map, err := luaContextGetSecurities(L) | ||||
| 	if err != nil { | ||||
| 		panic("luaGetSecurities couldn't fetch securities") | ||||
| 	} | ||||
|  | ||||
| 	table := L.NewTable() | ||||
|  | ||||
| 	for securityid := range security_map { | ||||
| 		table.RawSetInt(int(securityid), SecurityToLua(L, security_map[securityid])) | ||||
| 	} | ||||
|  | ||||
| 	L.Push(table) | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaRegisterSecurities(L *lua.LState) { | ||||
| 	mt := L.NewTypeMetatable(luaSecurityTypeName) | ||||
| 	L.SetGlobal("security", mt) | ||||
| 	L.SetField(mt, "__index", L.NewFunction(luaSecurity__index)) | ||||
| 	L.SetField(mt, "__tostring", L.NewFunction(luaSecurity__tostring)) | ||||
| 	L.SetField(mt, "__eq", L.NewFunction(luaSecurity__eq)) | ||||
| 	L.SetField(mt, "__metatable", lua.LString("protected")) | ||||
| 	getSecuritiesFn := L.NewFunction(luaGetSecurities) | ||||
| 	L.SetField(mt, "get_all", getSecuritiesFn) | ||||
| 	getDefaultCurrencyFn := L.NewFunction(luaGetDefaultCurrency) | ||||
| 	L.SetField(mt, "get_default", getDefaultCurrencyFn) | ||||
|  | ||||
| 	// also register the get_securities and get_default functions as globals in | ||||
| 	// their own right | ||||
| 	L.SetGlobal("get_securities", getSecuritiesFn) | ||||
| 	L.SetGlobal("get_default_currency", getDefaultCurrencyFn) | ||||
| } | ||||
|  | ||||
| func SecurityToLua(L *lua.LState, security *models.Security) *lua.LUserData { | ||||
| 	ud := L.NewUserData() | ||||
| 	ud.Value = security | ||||
| 	L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName)) | ||||
| 	return ud | ||||
| } | ||||
|  | ||||
| // Checks whether the first lua argument is a *LUserData with *Security and returns this *Security. | ||||
| func luaCheckSecurity(L *lua.LState, n int) *models.Security { | ||||
| 	ud := L.CheckUserData(n) | ||||
| 	if security, ok := ud.Value.(*models.Security); ok { | ||||
| 		return security | ||||
| 	} | ||||
| 	L.ArgError(n, "security expected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func luaSecurity__index(L *lua.LState) int { | ||||
| 	a := luaCheckSecurity(L, 1) | ||||
| 	field := L.CheckString(2) | ||||
|  | ||||
| 	switch field { | ||||
| 	case "SecurityId", "securityid": | ||||
| 		L.Push(lua.LNumber(float64(a.SecurityId))) | ||||
| 	case "Name", "name": | ||||
| 		L.Push(lua.LString(a.Name)) | ||||
| 	case "Description", "description": | ||||
| 		L.Push(lua.LString(a.Description)) | ||||
| 	case "Symbol", "symbol": | ||||
| 		L.Push(lua.LString(a.Symbol)) | ||||
| 	case "Precision", "precision": | ||||
| 		L.Push(lua.LNumber(float64(a.Precision))) | ||||
| 	case "Type", "type": | ||||
| 		L.Push(lua.LNumber(float64(a.Type))) | ||||
| 	case "ClosestPrice", "closestprice": | ||||
| 		L.Push(L.NewFunction(luaClosestPrice)) | ||||
| 	case "AlternateId", "alternateid": | ||||
| 		L.Push(lua.LString(a.AlternateId)) | ||||
| 	default: | ||||
| 		L.ArgError(2, "unexpected security attribute: "+field) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaClosestPrice(L *lua.LState) int { | ||||
| 	s := luaCheckSecurity(L, 1) | ||||
| 	c := luaCheckSecurity(L, 2) | ||||
| 	date := luaCheckTime(L, 3) | ||||
|  | ||||
| 	ctx := L.Context() | ||||
| 	tx, ok := ctx.Value(dbContextKey).(store.Tx) | ||||
| 	if !ok { | ||||
| 		panic("Couldn't find tx in lua's Context") | ||||
| 	} | ||||
|  | ||||
| 	p, err := GetClosestPrice(tx, s, c, date) | ||||
| 	if err != nil { | ||||
| 		L.Push(lua.LNil) | ||||
| 	} else { | ||||
| 		L.Push(PriceToLua(L, p)) | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSecurity__tostring(L *lua.LState) int { | ||||
| 	s := luaCheckSecurity(L, 1) | ||||
|  | ||||
| 	L.Push(lua.LString(s.Name + " - " + s.Description + " (" + s.Symbol + ")")) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func luaSecurity__eq(L *lua.LState) int { | ||||
| 	a := luaCheckSecurity(L, 1) | ||||
| 	b := luaCheckSecurity(L, 2) | ||||
|  | ||||
| 	L.Push(lua.LBool(a.SecurityId == b.SecurityId)) | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
| @@ -182,20 +182,6 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter | ||||
| 	return NewError(3 /*Invalid Request*/) | ||||
| } | ||||
|  | ||||
| func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) { | ||||
| 	var balance, tmp big.Rat | ||||
| 	for _, s := range *splits { | ||||
| 		rat_amount, err := models.GetBigAmount(s.Amount) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tmp.Add(&balance, rat_amount) | ||||
| 		balance.Set(&tmp) | ||||
| 	} | ||||
|  | ||||
| 	return &balance, nil | ||||
| } | ||||
|  | ||||
| // Return only those transactions which have at least one split pertaining to | ||||
| // an account | ||||
| func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user