From 9624f0c5bcddd4622dfc0c31f329a5563e3acc00 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sun, 12 Nov 2017 21:12:49 -0500 Subject: [PATCH] Move to a consistent way of handling IDs in URLs --- internal/handlers/accounts.go | 47 +++++++++++-------------------- internal/handlers/handlers.go | 9 ++++++ internal/handlers/imports.go | 5 ++-- internal/handlers/prices.go | 26 +++++++++-------- internal/handlers/reports.go | 32 ++++++++------------- internal/handlers/securities.go | 11 ++++---- internal/handlers/transactions.go | 15 ++++------ internal/handlers/users.go | 2 +- internal/handlers/util.go | 13 --------- 9 files changed, 66 insertions(+), 94 deletions(-) diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 379c1ce..562e053 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -5,7 +5,6 @@ import ( "errors" "log" "net/http" - "regexp" "strings" ) @@ -100,14 +99,6 @@ type AccountList struct { Accounts *[]Account `json:"accounts"` } -var accountTransactionsRE *regexp.Regexp -var accountImportRE *regexp.Regexp - -func init() { - accountTransactionsRE = regexp.MustCompile(`^/v1/accounts/[0-9]+/transactions/?$`) - accountImportRE = regexp.MustCompile(`^/v1/accounts/[0-9]+/imports/[a-z]+/?$`) -} - func (a *Account) Write(w http.ResponseWriter) error { enc := json.NewEncoder(w) return enc.Encode(a) @@ -384,18 +375,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "POST" { - // if URL looks like /v1/accounts/[0-9]+/imports, use the account - // import handler - if accountImportRE.MatchString(r.URL.Path) { - var accountid int64 - var importtype string - n, err := GetURLPieces(r.URL.Path, "/v1/accounts/%d/imports/%s", &accountid, &importtype) - - if err != nil || n != 2 { - log.Print(err) - return NewError(999 /*Internal Error*/) + if !context.LastLevel() { + accountid, err := context.NextID() + if err != nil || context.NextLevel() != "imports" { + return NewError(3 /*Invalid Request*/) } - return AccountImportHandler(context, r, user, accountid, importtype) + return AccountImportHandler(context, r, user, accountid) } account_json := r.PostFormValue("account") @@ -433,10 +418,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return ResponseWrapper{201, &account} } else if r.Method == "GET" { - var accountid int64 - n, err := GetURLPieces(r.URL.Path, "/v1/accounts/%d", &accountid) - - if err != nil || n != 1 { + if context.LastLevel() { //Return all Accounts var al AccountList accounts, err := GetAccounts(context.Tx, user.UserId) @@ -446,13 +428,14 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { } al.Accounts = accounts return &al - } else { - // if URL looks like /account/[0-9]+/transactions, use the account - // transaction handler - if accountTransactionsRE.MatchString(r.URL.Path) { - return AccountTransactionsHandler(context, r, user, accountid) - } + } + accountid, err := context.NextID() + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + if context.LastLevel() { // Return Account with this Id account, err := GetAccount(context.Tx, accountid, user.UserId) if err != nil { @@ -460,9 +443,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { } return account + } else if context.NextLevel() == "transactions" { + return AccountTransactionsHandler(context, r, user, accountid) } } else { - accountid, err := GetURLID(r.URL.Path) + accountid, err := context.NextID() if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 12d1b55..e4f1294 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -5,6 +5,7 @@ import ( "log" "net/http" "path" + "strconv" "strings" ) @@ -35,6 +36,14 @@ func (c *Context) NextLevel() string { return split[0] } +func (c *Context) NextID() (int64, error) { + return strconv.ParseInt(c.NextLevel(), 0, 64) +} + +func (c *Context) LastLevel() bool { + return len(c.remainingURL) == 0 +} + type Handler func(*http.Request, *Context) ResponseWriterWriter type APIHandler struct { diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 3fe5cde..d40038c 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -335,9 +335,10 @@ func OFXFileImportHandler(context *Context, r *http.Request, user *User, account /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter { +func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { - switch importtype { + importType := context.NextLevel() + switch importType { case "ofx": return OFXImportHandler(context, r, user, accountid) case "ofxfile": diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index e2240fd..13b0f67 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -165,10 +165,7 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { return ResponseWrapper{201, &price} } else if r.Method == "GET" { - var priceid int64 - n, err := GetURLPieces(r.URL.Path, "/v1/prices/%d", &priceid) - - if err != nil || n != 1 { + if context.LastLevel() { //Return all prices var pl PriceList @@ -180,16 +177,21 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { pl.Prices = prices return &pl - } else { - price, err := GetPrice(context.Tx, priceid, user.UserId) - if err != nil { - return NewError(3 /*Invalid Request*/) - } - - return price } + + priceid, err := context.NextID() + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + price, err := GetPrice(context.Tx, priceid, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + return price } else { - priceid, err := GetURLID(r.URL.Path) + priceid, err := context.NextID() if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index 905ba01..c9611d1 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -8,17 +8,10 @@ import ( "github.com/yuin/gopher-lua" "log" "net/http" - "regexp" "strings" "time" ) -var reportTabulationRE *regexp.Regexp - -func init() { - reportTabulationRE = regexp.MustCompile(`^/v1/reports/[0-9]+/tabulations/?$`) -} - //type and value to store user in lua's Context type key int @@ -255,19 +248,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { return ResponseWrapper{201, &report} } else if r.Method == "GET" { - if reportTabulationRE.MatchString(r.URL.Path) { - var reportid int64 - n, err := GetURLPieces(r.URL.Path, "/v1/reports/%d/tabulations", &reportid) - if err != nil || n != 1 { - log.Print(err) - return NewError(999 /*InternalError*/) - } - return ReportTabulationHandler(context.Tx, r, user, reportid) - } - - var reportid int64 - n, err := GetURLPieces(r.URL.Path, "/v1/reports/%d", &reportid) - if err != nil || n != 1 { + if context.LastLevel() { //Return all Reports var rl ReportList reports, err := GetReports(context.Tx, user.UserId) @@ -277,6 +258,15 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } rl.Reports = reports return &rl + } + + reportid, err := context.NextID() + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + if context.NextLevel() == "tabulations" { + return ReportTabulationHandler(context.Tx, r, user, reportid) } else { // Return Report with this Id report, err := GetReport(context.Tx, reportid, user.UserId) @@ -287,7 +277,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { return report } } else { - reportid, err := GetURLID(r.URL.Path) + reportid, err := context.NextID() if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 03ce542..1608973 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -274,10 +274,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return ResponseWrapper{201, &security} } else if r.Method == "GET" { - var securityid int64 - n, err := GetURLPieces(r.URL.Path, "/v1/securities/%d", &securityid) - - if err != nil || n != 1 { + if context.LastLevel() { //Return all securities var sl SecurityList @@ -290,6 +287,10 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { sl.Securities = securities return &sl } else { + securityid, err := context.NextID() + if err != nil { + return NewError(3 /*Invalid Request*/) + } security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) @@ -298,7 +299,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return security } } else { - securityid, err := GetURLID(r.URL.Path) + securityid, err := context.NextID() if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 5752b5d..4cba509 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -452,9 +452,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return &transaction } else if r.Method == "GET" { - transactionid, err := GetURLID(r.URL.Path) - - if err != nil { + if context.LastLevel() { //Return all Transactions var al TransactionList transactions, err := GetTransactions(context.Tx, user.UserId) @@ -466,6 +464,10 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return &al } else { //Return Transaction with this Id + transactionid, err := context.NextID() + if err != nil { + return NewError(3 /*Invalid Request*/) + } transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) @@ -473,7 +475,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return transaction } } else { - transactionid, err := GetURLID(r.URL.Path) + transactionid, err := context.NextID() if err != nil { return NewError(3 /*Invalid Request*/) } @@ -518,11 +520,6 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return &transaction } else if r.Method == "DELETE" { - transactionid, err := GetURLID(r.URL.Path) - if err != nil { - return NewError(3 /*Invalid Request*/) - } - transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 5041046..b3dc2e9 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -207,7 +207,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(1 /*Not Signed In*/) } - userid, err := GetURLID(r.URL.Path) + userid, err := context.NextID() if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/util.go b/internal/handlers/util.go index 2533fcc..f83622b 100644 --- a/internal/handlers/util.go +++ b/internal/handlers/util.go @@ -3,21 +3,8 @@ package handlers import ( "fmt" "net/http" - "strconv" - "strings" ) -func GetURLID(url string) (int64, error) { - pieces := strings.Split(strings.Trim(url, "/"), "/") - return strconv.ParseInt(pieces[len(pieces)-1], 10, 0) -} - -func GetURLPieces(url string, format string, a ...interface{}) (int, error) { - url = strings.Replace(url, "/", " ", -1) - format = strings.Replace(format, "/", " ", -1) - return fmt.Sscanf(url, format, a...) -} - type ResponseWrapper struct { Code int Writer ResponseWriterWriter