server: Surround sqlite3 transactions with a lock to handle concurrency
This commit is contained in:
		| @@ -4,9 +4,15 @@ import ( | |||||||
| 	"asink" | 	"asink" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	_ "github.com/mattn/go-sqlite3" | 	_ "github.com/mattn/go-sqlite3" | ||||||
|  | 	"sync" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetAndInitDB() (*sql.DB, error) { | type AsinkDB struct { | ||||||
|  | 	db   *sql.DB | ||||||
|  | 	lock sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetAndInitDB() (*AsinkDB, error) { | ||||||
| 	dbLocation := "asink-server.db" //TODO make me configurable | 	dbLocation := "asink-server.db" //TODO make me configurable | ||||||
|  |  | ||||||
| 	db, err := sql.Open("sqlite3", dbLocation) | 	db, err := sql.Open("sqlite3", dbLocation) | ||||||
| @@ -33,11 +39,14 @@ func GetAndInitDB() (*sql.DB, error) { | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return db, nil | 	ret := new(AsinkDB) | ||||||
|  | 	ret.db = db | ||||||
|  | 	return ret, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func DatabaseAddEvent(db *sql.DB, e *asink.Event) (err error) { | func (adb *AsinkDB) DatabaseAddEvent(e *asink.Event) (err error) { | ||||||
| 	tx, err := db.Begin() | 	adb.lock.Lock() | ||||||
|  | 	tx, err := adb.db.Begin() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -47,6 +56,7 @@ func DatabaseAddEvent(db *sql.DB, e *asink.Event) (err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			tx.Rollback() | 			tx.Rollback() | ||||||
| 		} | 		} | ||||||
|  | 		adb.lock.Unlock() | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	result, err := tx.Exec("INSERT INTO events (localid, type, status, path, hash, timestamp, permissions) VALUES (?,?,?,?,?,?,?);", e.LocalId, e.Type, e.Status, e.Path, e.Hash, e.Timestamp, e.Permissions) | 	result, err := tx.Exec("INSERT INTO events (localid, type, status, path, hash, timestamp, permissions) VALUES (?,?,?,?,?,?,?);", e.LocalId, e.Type, e.Status, e.Path, e.Hash, e.Timestamp, e.Permissions) | ||||||
| @@ -67,8 +77,13 @@ func DatabaseAddEvent(db *sql.DB, e *asink.Event) (err error) { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func DatabaseRetrieveEvents(db *sql.DB, firstId uint64, maxEvents uint) (events []*asink.Event, err error) { | func (adb *AsinkDB) DatabaseRetrieveEvents(firstId uint64, maxEvents uint) (events []*asink.Event, err error) { | ||||||
| 	rows, err := db.Query("SELECT id, localid, type, status, path, hash, timestamp, permissions FROM events WHERE id >= ? ORDER BY id ASC LIMIT ?;", firstId, maxEvents) | 	adb.lock.Lock() | ||||||
|  | 	//make sure the database gets unlocked on return | ||||||
|  | 	defer func() { | ||||||
|  | 		adb.lock.Unlock() | ||||||
|  | 	}() | ||||||
|  | 	rows, err := adb.db.Query("SELECT id, localid, type, status, path, hash, timestamp, permissions FROM events WHERE id >= ? ORDER BY id ASC LIMIT ?;", firstId, maxEvents) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -2,7 +2,6 @@ package main | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"asink" | 	"asink" | ||||||
| 	"database/sql" |  | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"flag" | 	"flag" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -15,7 +14,7 @@ import ( | |||||||
| //global variables | //global variables | ||||||
| var eventsRegexp *regexp.Regexp | var eventsRegexp *regexp.Regexp | ||||||
| var port int = 8080 | var port int = 8080 | ||||||
| var db *sql.DB | var adb *AsinkDB | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	var err error | 	var err error | ||||||
| @@ -26,7 +25,7 @@ func init() { | |||||||
|  |  | ||||||
| 	eventsRegexp = regexp.MustCompile("^/events/([0-9]+)$") | 	eventsRegexp = regexp.MustCompile("^/events/([0-9]+)$") | ||||||
|  |  | ||||||
| 	db, err = GetAndInitDB() | 	adb, err = GetAndInitDB() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
| @@ -74,7 +73,7 @@ func getEvents(w http.ResponseWriter, r *http.Request, nextEvent uint64) { | |||||||
| 		w.Write(b) | 		w.Write(b) | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	events, err := DatabaseRetrieveEvents(db, nextEvent, 50) | 	events, err := adb.DatabaseRetrieveEvents(nextEvent, 50) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 		error_message = err.Error() | 		error_message = err.Error() | ||||||
| @@ -126,7 +125,7 @@ func putEvents(w http.ResponseWriter, r *http.Request) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	for _, event := range events.Events { | 	for _, event := range events.Events { | ||||||
| 		err = DatabaseAddEvent(db, event) | 		err = adb.DatabaseAddEvent(event) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			//TODO should probably do this in a way that the caller knows how many of these have failed and doesn't re-try sending ones that succeeded | 			//TODO should probably do this in a way that the caller knows how many of these have failed and doesn't re-try sending ones that succeeded | ||||||
| 			//i.e. add this to the return codes or something | 			//i.e. add this to the return codes or something | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user