mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-10-30 09:33:25 -04:00 
			
		
		
		
	Add tests for users and sessions
Split out common test infrastructure from security_templates_test, make tests HTTPS, use the http.Client provided by httptest
This commit is contained in:
		
							
								
								
									
										66
									
								
								internal/handlers/common_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								internal/handlers/common_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| package handlers_test | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"github.com/aclindsa/moneygo/internal/config" | ||||
| 	"github.com/aclindsa/moneygo/internal/db" | ||||
| 	"github.com/aclindsa/moneygo/internal/handlers" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| var server *httptest.Server | ||||
|  | ||||
| func Delete(client *http.Client, url string) (*http.Response, error) { | ||||
| 	request, err := http.NewRequest(http.MethodDelete, url, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return client.Do(request) | ||||
| } | ||||
|  | ||||
| func PutForm(client *http.Client, url string, data url.Values) (*http.Response, error) { | ||||
| 	request, err := http.NewRequest(http.MethodPut, url, strings.NewReader(data.Encode())) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	request.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 	return client.Do(request) | ||||
| } | ||||
|  | ||||
| func RunTests(m *testing.M) int { | ||||
| 	tmpdir, err := ioutil.TempDir("./", "handlertest") | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer os.RemoveAll(tmpdir) | ||||
|  | ||||
| 	dbpath := path.Join(tmpdir, "moneygo.sqlite") | ||||
| 	database, err := sql.Open("sqlite3", "file:"+dbpath+"?cache=shared&mode=rwc") | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer database.Close() | ||||
|  | ||||
| 	dbmap, err := db.GetDbMap(database, config.SQLite) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	servemux := handlers.GetHandler(dbmap) | ||||
| 	server = httptest.NewTLSServer(servemux) | ||||
| 	defer server.Close() | ||||
|  | ||||
| 	return m.Run() | ||||
| } | ||||
|  | ||||
| func TestMain(m *testing.M) { | ||||
| 	os.Exit(RunTests(m)) | ||||
| } | ||||
| @@ -1,54 +1,14 @@ | ||||
| package handlers_test | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"github.com/aclindsa/moneygo/internal/config" | ||||
| 	"github.com/aclindsa/moneygo/internal/db" | ||||
| 	"github.com/aclindsa/moneygo/internal/handlers" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| var server *httptest.Server | ||||
| 
 | ||||
| func RunTests(m *testing.M) int { | ||||
| 	tmpdir, err := ioutil.TempDir("./", "handlertest") | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer os.RemoveAll(tmpdir) | ||||
| 
 | ||||
| 	dbpath := path.Join(tmpdir, "moneygo.sqlite") | ||||
| 	database, err := sql.Open("sqlite3", "file:"+dbpath+"?cache=shared&mode=rwc") | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 	defer database.Close() | ||||
| 
 | ||||
| 	dbmap, err := db.GetDbMap(database, config.SQLite) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	servemux := handlers.GetHandler(dbmap) | ||||
| 	server = httptest.NewServer(servemux) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	return m.Run() | ||||
| } | ||||
| 
 | ||||
| func TestMain(m *testing.M) { | ||||
| 	os.Exit(RunTests(m)) | ||||
| } | ||||
| 
 | ||||
| func TestSecurityTemplates(t *testing.T) { | ||||
| 	var sl handlers.SecurityList | ||||
| 	response, err := http.Get(server.URL + "/securitytemplate/?search=USD&type=currency") | ||||
| 	response, err := server.Client().Get(server.URL + "/securitytemplate/?search=USD&type=currency") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -85,7 +45,7 @@ func TestSecurityTemplates(t *testing.T) { | ||||
| 
 | ||||
| func TestSecurityTemplateLimit(t *testing.T) { | ||||
| 	var sl handlers.SecurityList | ||||
| 	response, err := http.Get(server.URL + "/securitytemplate/?search=e&limit=5") | ||||
| 	response, err := server.Client().Get(server.URL + "/securitytemplate/?search=e&limit=5") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -111,7 +71,7 @@ func TestSecurityTemplateLimit(t *testing.T) { | ||||
| 
 | ||||
| func TestSecurityTemplateInvalidType(t *testing.T) { | ||||
| 	var e handlers.Error | ||||
| 	response, err := http.Get(server.URL + "/securitytemplate/?search=e&type=blah") | ||||
| 	response, err := server.Client().Get(server.URL + "/securitytemplate/?search=e&type=blah") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -134,7 +94,7 @@ func TestSecurityTemplateInvalidType(t *testing.T) { | ||||
| 
 | ||||
| func TestSecurityTemplateInvalidLimit(t *testing.T) { | ||||
| 	var e handlers.Error | ||||
| 	response, err := http.Get(server.URL + "/securitytemplate/?search=e&type=Currency&limit=foo") | ||||
| 	response, err := server.Client().Get(server.URL + "/securitytemplate/?search=e&type=Currency&limit=foo") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| @@ -22,6 +23,11 @@ func (s *Session) Write(w http.ResponseWriter) error { | ||||
| 	return enc.Encode(s) | ||||
| } | ||||
|  | ||||
| func (s *Session) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(s) | ||||
| } | ||||
|  | ||||
| func GetSession(db *DB, r *http.Request) (*Session, error) { | ||||
| 	var s Session | ||||
|  | ||||
|   | ||||
							
								
								
									
										144
									
								
								internal/handlers/sessions_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								internal/handlers/sessions_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | ||||
| package handlers_test | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/handlers" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/cookiejar" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func newSession(user *User) (*http.Client, error) { | ||||
| 	var u User | ||||
| 	var e handlers.Error | ||||
|  | ||||
| 	jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: nil}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	client := server.Client() | ||||
| 	client.Jar = jar | ||||
|  | ||||
| 	bytes, err := json.Marshal(user) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	response, err := client.PostForm(server.URL+"/session/", url.Values{"user": {string(bytes)}}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	response.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = (&u).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = (&e).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if e.ErrorId != 0 || len(e.ErrorString) != 0 { | ||||
| 		return nil, fmt.Errorf("Unexpected error when creating session %+v", e) | ||||
| 	} | ||||
|  | ||||
| 	return client, nil | ||||
| } | ||||
|  | ||||
| func getSession(client *http.Client) (*handlers.Session, error) { | ||||
| 	var s handlers.Session | ||||
| 	response, err := client.Get(server.URL + "/session/") | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	response.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = (&s).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func sessionExistsOrError(c *http.Client) error { | ||||
|  | ||||
| 	url, err := url.Parse(server.URL) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	cookies := c.Jar.Cookies(url) | ||||
|  | ||||
| 	var found_session bool = false | ||||
| 	for _, cookie := range cookies { | ||||
| 		if cookie.Name == "moneygo-session" { | ||||
| 			found_session = true | ||||
| 		} | ||||
| 	} | ||||
| 	if found_session { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return fmt.Errorf("Didn't find 'moneygo-session' cookie in CookieJar") | ||||
| } | ||||
|  | ||||
| func TestCreateSession(t *testing.T) { | ||||
| 	u, err := createUser(&users[0]) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	u.Password = users[0].Password | ||||
|  | ||||
| 	client, err := newSession(u) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer deleteUser(client, u) | ||||
| 	if err := sessionExistsOrError(client); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestGetSession(t *testing.T) { | ||||
| 	u, err := createUser(&users[0]) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	u.Password = users[0].Password | ||||
|  | ||||
| 	client, err := newSession(u) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer deleteUser(client, u) | ||||
| 	session, err := getSession(client) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if len(session.SessionSecret) != 0 { | ||||
| 		t.Error("Session.SessionSecret should not be passed back in JSON") | ||||
| 	} | ||||
|  | ||||
| 	if session.UserId != u.UserId { | ||||
| 		t.Errorf("session's UserId (%d) should equal user's UserID (%d)", session.UserId, u.UserId) | ||||
| 	} | ||||
|  | ||||
| 	if session.SessionId == 0 { | ||||
| 		t.Error("session's SessionId should not be 0") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										39
									
								
								internal/handlers/testdata_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								internal/handlers/testdata_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package handlers_test | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // Needed because handlers.User doesn't allow Password to be written to JSON | ||||
|  | ||||
| type User struct { | ||||
| 	UserId          int64 | ||||
| 	DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user | ||||
| 	Name            string | ||||
| 	Username        string | ||||
| 	Password        string | ||||
| 	PasswordHash    string | ||||
| 	Email           string | ||||
| } | ||||
|  | ||||
| func (u *User) Write(w http.ResponseWriter) error { | ||||
| 	enc := json.NewEncoder(w) | ||||
| 	return enc.Encode(u) | ||||
| } | ||||
|  | ||||
| func (u *User) Read(json_str string) error { | ||||
| 	dec := json.NewDecoder(strings.NewReader(json_str)) | ||||
| 	return dec.Decode(u) | ||||
| } | ||||
|  | ||||
| var users = []User{ | ||||
| 	User{ | ||||
| 		DefaultCurrency: 840, // USD | ||||
| 		Name:            "John Smith", | ||||
| 		Username:        "jsmith", | ||||
| 		Password:        "hunter2", | ||||
| 		Email:           "jsmith@example.com", | ||||
| 	}, | ||||
| } | ||||
							
								
								
									
										218
									
								
								internal/handlers/users_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										218
									
								
								internal/handlers/users_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,218 @@ | ||||
| package handlers_test | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/handlers" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func createUser(user *User) (*User, error) { | ||||
| 	bytes, err := json.Marshal(user) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	response, err := server.Client().PostForm(server.URL+"/user/", url.Values{"user": {string(bytes)}}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	response.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var e handlers.Error | ||||
| 	err = (&e).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if e.ErrorId != 0 || len(e.ErrorString) != 0 { | ||||
| 		return nil, fmt.Errorf("Error when creating user %+v", e) | ||||
| 	} | ||||
|  | ||||
| 	var u User | ||||
| 	err = (&u).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if u.UserId == 0 || len(u.Username) == 0 { | ||||
| 		return nil, fmt.Errorf("Unable to create user: %+v", user) | ||||
| 	} | ||||
|  | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func updateUser(client *http.Client, user *User) (*User, error) { | ||||
| 	bytes, err := json.Marshal(user) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	response, err := PutForm(client, server.URL+"/user/"+strconv.FormatInt(user.UserId, 10), url.Values{"user": {string(bytes)}}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	response.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var e handlers.Error | ||||
| 	err = (&e).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if e.ErrorId != 0 || len(e.ErrorString) != 0 { | ||||
| 		return nil, fmt.Errorf("Error when updating user %+v", e) | ||||
| 	} | ||||
|  | ||||
| 	var u User | ||||
| 	err = (&u).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if u.UserId == 0 || len(u.Username) == 0 { | ||||
| 		return nil, fmt.Errorf("Unable to update user: %+v", user) | ||||
| 	} | ||||
|  | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func deleteUser(client *http.Client, u *User) error { | ||||
| 	response, err := Delete(client, server.URL+"/user/"+strconv.FormatInt(u.UserId, 10)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	response.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var e handlers.Error | ||||
| 	err = (&e).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if e.ErrorId != 0 || len(e.ErrorString) != 0 { | ||||
| 		return fmt.Errorf("Error when deleting user %+v", e) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getUser(client *http.Client, userid int64) (*User, error) { | ||||
| 	response, err := client.Get(server.URL + "/user/" + strconv.FormatInt(userid, 10)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	response.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var e handlers.Error | ||||
| 	err = (&e).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if e.ErrorId != 0 || len(e.ErrorString) != 0 { | ||||
| 		return nil, fmt.Errorf("Error when get user %+v", e) | ||||
| 	} | ||||
|  | ||||
| 	var u User | ||||
| 	err = (&u).Read(string(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if u.UserId == 0 || len(u.Username) == 0 { | ||||
| 		return nil, fmt.Errorf("Unable to get userid: %d", userid) | ||||
| 	} | ||||
|  | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func TestCreateUser(t *testing.T) { | ||||
| 	u, err := createUser(&users[0]) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if len(u.Password) != 0 || len(u.PasswordHash) != 0 { | ||||
| 		t.Error("Never send password, only send password hash when necessary") | ||||
| 	} | ||||
|  | ||||
| 	u.Password = users[0].Password | ||||
|  | ||||
| 	client, err := newSession(u) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating new session, user not deleted (may cause errors in other tests): %s", err) | ||||
| 	} | ||||
| 	defer deleteUser(client, u) | ||||
| } | ||||
|  | ||||
| func TestGetUser(t *testing.T) { | ||||
| 	origu, err := createUser(&users[0]) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	origu.Password = users[0].Password | ||||
|  | ||||
| 	client, err := newSession(origu) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating new session, user not deleted (may cause errors in other tests): %s", err) | ||||
| 	} | ||||
| 	defer deleteUser(client, origu) | ||||
|  | ||||
| 	u, err := getUser(client, origu.UserId) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error fetching user: %s\n", err) | ||||
| 	} | ||||
| 	if u.UserId != origu.UserId { | ||||
| 		t.Errorf("UserId doesn't match") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestUpdateUser(t *testing.T) { | ||||
| 	origu, err := createUser(&users[0]) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	origu.Password = users[0].Password | ||||
|  | ||||
| 	client, err := newSession(origu) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating new session, user not deleted (may cause errors in other tests): %s", err) | ||||
| 	} | ||||
| 	defer deleteUser(client, origu) | ||||
|  | ||||
| 	origu.Name = "Bob" | ||||
| 	origu.Email = "bob@example.com" | ||||
|  | ||||
| 	u, err := updateUser(client, origu) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error updating user: %s\n", err) | ||||
| 	} | ||||
| 	if u.UserId != origu.UserId { | ||||
| 		t.Errorf("UserId doesn't match") | ||||
| 	} | ||||
| 	if u.Name != origu.Name { | ||||
| 		t.Errorf("UserId doesn't match") | ||||
| 	} | ||||
| 	if u.Email != origu.Email { | ||||
| 		t.Errorf("UserId doesn't match") | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user