diff --git a/reports.go b/reports.go index 034eebd..c0f90dc 100644 --- a/reports.go +++ b/reports.go @@ -2,9 +2,11 @@ package main import ( "context" + "errors" "github.com/yuin/gopher-lua" "log" "net/http" + "os" "path" "time" ) @@ -21,6 +23,68 @@ const ( const luaTimeoutSeconds time.Duration = 5 // maximum time a lua request can run for +func runReport(user *User, reportpath string) (*Report, error) { + // Create a new LState without opening the default libs for security + L := lua.NewState(lua.Options{SkipOpenLibs: true}) + defer L.Close() + + // Create a new context holding the current user with a timeout + ctx := context.WithValue(context.Background(), userContextKey, user) + ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) + defer cancel() + L.SetContext(ctx) + + for _, pair := range []struct { + n string + f lua.LGFunction + }{ + {lua.LoadLibName, lua.OpenPackage}, // Must be first + {lua.BaseLibName, lua.OpenBase}, + {lua.TabLibName, lua.OpenTable}, + {lua.StringLibName, lua.OpenString}, + {lua.MathLibName, lua.OpenMath}, + } { + if err := L.CallByParam(lua.P{ + Fn: L.NewFunction(pair.f), + NRet: 0, + Protect: true, + }, lua.LString(pair.n)); err != nil { + return nil, errors.New("Error initializing Lua packages") + } + } + + luaRegisterAccounts(L) + luaRegisterSecurities(L) + luaRegisterBalances(L) + luaRegisterDates(L) + luaRegisterReports(L) + + err := L.DoFile(reportpath) + + if err != nil { + return nil, err + } + + if err := L.CallByParam(lua.P{ + Fn: L.GetGlobal("generate"), + NRet: 1, + Protect: true, + }); err != nil { + return nil, err + } + + value := L.Get(-1) + if ud, ok := value.(*lua.LUserData); ok { + if report, ok := ud.Value.(*Report); ok { + return report, nil + } else { + return nil, errors.New("generate() in " + reportpath + " didn't return a report") + } + } else { + return nil, errors.New("generate() in " + reportpath + " didn't return a report") + } +} + func ReportHandler(w http.ResponseWriter, r *http.Request) { user, err := GetUserFromSession(r) if err != nil { @@ -37,48 +101,23 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { } reportpath := path.Join(baseDir, "reports", reportname+".lua") - - // Create a new LState without opening the default libs for security - L := lua.NewState(lua.Options{SkipOpenLibs: true}) - defer L.Close() - - // Create a new context holding the current user with a timeout - ctx := context.WithValue(context.Background(), userContextKey, user) - ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) - defer cancel() - L.SetContext(ctx) - - for _, pair := range []struct { - n string - f lua.LGFunction - }{ - {lua.LoadLibName, lua.OpenPackage}, // Must be first - {lua.BaseLibName, lua.OpenBase}, - {lua.TabLibName, lua.OpenTable}, - {lua.StringLibName, lua.OpenString}, - {lua.MathLibName, lua.OpenMath}, - } { - if err := L.CallByParam(lua.P{ - Fn: L.NewFunction(pair.f), - NRet: 0, - Protect: true, - }, lua.LString(pair.n)); err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + report_stat, err := os.Stat(reportpath) + if err != nil || !report_stat.Mode().IsRegular() { + WriteError(w, 3 /*Invalid Request*/) + return } - luaRegisterAccounts(L) - luaRegisterSecurities(L) - luaRegisterBalances(L) - luaRegisterDates(L) - - err = L.DoFile(reportpath) - + report, err := runReport(user, reportpath) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - log.Print("lua:" + err.Error()) + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + + err = report.Write(w) + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) return } } diff --git a/reports_lua.go b/reports_lua.go new file mode 100644 index 0000000..2e11685 --- /dev/null +++ b/reports_lua.go @@ -0,0 +1,169 @@ +package main + +import ( + "encoding/json" + "github.com/yuin/gopher-lua" + "net/http" +) + +const luaReportTypeName = "report" +const luaSeriesTypeName = "series" + +type Series struct { + Values []float64 + Children map[string]*Series +} + +type Report struct { + Title string + Subtitle string + XAxisLabel string + YAxisLabel string + Labels []string + Series map[string]*Series +} + +func (r *Report) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(r) +} + +func luaRegisterReports(L *lua.LState) { + mtr := L.NewTypeMetatable(luaReportTypeName) + L.SetGlobal("report", mtr) + L.SetField(mtr, "new", L.NewFunction(luaReportNew)) + L.SetField(mtr, "__index", L.NewFunction(luaReport__index)) + L.SetField(mtr, "__metatable", lua.LString("protected")) + + mts := L.NewTypeMetatable(luaSeriesTypeName) + L.SetGlobal("series", mts) + L.SetField(mts, "__index", L.NewFunction(luaSeries__index)) + L.SetField(mts, "__metatable", lua.LString("protected")) +} + +// Checks whether the first lua argument is a *LUserData with *Report and returns *Report +func luaCheckReport(L *lua.LState, n int) *Report { + ud := L.CheckUserData(n) + if report, ok := ud.Value.(*Report); ok { + return report + } + L.ArgError(n, "report expected") + return nil +} + +// Checks whether the first lua argument is a *LUserData with *Series and returns *Series +func luaCheckSeries(L *lua.LState, n int) *Series { + ud := L.CheckUserData(n) + if series, ok := ud.Value.(*Series); ok { + return series + } + L.ArgError(n, "series expected") + return nil +} + +func luaReportNew(L *lua.LState) int { + numvalues := L.CheckInt(1) + ud := L.NewUserData() + ud.Value = &Report{ + Labels: make([]string, numvalues), + Series: make(map[string]*Series), + } + L.SetMetatable(ud, L.GetTypeMetatable(luaReportTypeName)) + L.Push(ud) + return 1 +} + +func luaReport__index(L *lua.LState) int { + field := L.CheckString(2) + + switch field { + case "Label", "label": + L.Push(L.NewFunction(luaReportLabel)) + case "Series", "series": + L.Push(L.NewFunction(luaReportSeries)) + default: + L.ArgError(2, "unexpected report attribute: "+field) + } + + return 1 +} + +func luaReportLabel(L *lua.LState) int { + report := luaCheckReport(L, 1) + labelnumber := L.CheckInt(2) + label := L.CheckString(3) + + if labelnumber > cap(report.Labels) || labelnumber < 1 { + L.ArgError(2, "Label index must be between 1 and the number of data points, inclusive") + } + report.Labels[labelnumber-1] = label + return 0 +} + +func luaReportSeries(L *lua.LState) int { + report := luaCheckReport(L, 1) + name := L.CheckString(2) + ud := L.NewUserData() + + s, ok := report.Series[name] + if ok { + ud.Value = s + } else { + report.Series[name] = &Series{ + Children: make(map[string]*Series), + Values: make([]float64, cap(report.Labels)), + } + ud.Value = report.Series[name] + } + L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName)) + L.Push(ud) + return 1 +} + +func luaSeries__index(L *lua.LState) int { + field := L.CheckString(2) + + switch field { + case "Value", "value": + L.Push(L.NewFunction(luaSeriesValue)) + case "Series", "series", "Child", "child": + L.Push(L.NewFunction(luaSeriesChildren)) + default: + L.ArgError(2, "unexpected series attribute: "+field) + } + + return 1 +} + +func luaSeriesValue(L *lua.LState) int { + series := luaCheckSeries(L, 1) + valuenumber := L.CheckInt(2) + value := float64(L.CheckNumber(3)) + + if valuenumber > cap(series.Values) || valuenumber < 1 { + L.ArgError(2, "value index must be between 1 and the number of data points, inclusive") + } + series.Values[valuenumber-1] = value + + return 0 +} + +func luaSeriesChildren(L *lua.LState) int { + parent := luaCheckSeries(L, 1) + name := L.CheckString(2) + ud := L.NewUserData() + + s, ok := parent.Children[name] + if ok { + ud.Value = s + } else { + parent.Children[name] = &Series{ + Children: make(map[string]*Series), + Values: make([]float64, cap(parent.Values)), + } + ud.Value = parent.Children[name] + } + L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName)) + L.Push(ud) + return 1 +}