mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-10-30 17:33:26 -04:00 
			
		
		
		
	Lay groundwork and move sessions to 'store'
This commit is contained in:
		
							
								
								
									
										100
									
								
								internal/store/db/db.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								internal/store/db/db.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/gorp" | ||||
| 	"github.com/aclindsa/moneygo/internal/config" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"github.com/aclindsa/moneygo/internal/store" | ||||
| 	_ "github.com/go-sql-driver/mysql" | ||||
| 	_ "github.com/lib/pq" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| 	"log" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // luaMaxLengthBuffer is intended to be enough bytes such that a given string | ||||
| // no longer than models.LuaMaxLength is sure to fit within a database | ||||
| // implementation's string type specified by the same. | ||||
| const luaMaxLengthBuffer int = 4096 | ||||
|  | ||||
| func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { | ||||
| 	var dialect gorp.Dialect | ||||
| 	if dbtype == config.SQLite { | ||||
| 		dialect = gorp.SqliteDialect{} | ||||
| 	} else if dbtype == config.MySQL { | ||||
| 		dialect = gorp.MySQLDialect{ | ||||
| 			Engine:   "InnoDB", | ||||
| 			Encoding: "UTF8", | ||||
| 		} | ||||
| 	} else if dbtype == config.Postgres { | ||||
| 		dialect = gorp.PostgresDialect{ | ||||
| 			LowercaseFields: true, | ||||
| 		} | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", dbtype.String()) | ||||
| 	} | ||||
|  | ||||
| 	dbmap := &gorp.DbMap{Db: db, Dialect: dialect} | ||||
| 	dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") | ||||
| 	dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") | ||||
| 	dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") | ||||
| 	dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") | ||||
| 	dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") | ||||
| 	dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") | ||||
| 	dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") | ||||
| 	rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") | ||||
| 	rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) | ||||
|  | ||||
| 	err := dbmap.CreateTablesIfNotExists() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return dbmap, nil | ||||
| } | ||||
|  | ||||
| func GetDSN(dbtype config.DbType, dsn string) string { | ||||
| 	if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") { | ||||
| 		log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!") | ||||
| 	} | ||||
| 	return dsn | ||||
| } | ||||
|  | ||||
| type DbStore struct { | ||||
| 	DbMap *gorp.DbMap | ||||
| } | ||||
|  | ||||
| func (db *DbStore) Begin() (store.Tx, error) { | ||||
| 	tx, err := db.DbMap.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Tx{db.DbMap.Dialect, tx}, nil | ||||
| } | ||||
|  | ||||
| func (db *DbStore) Close() error { | ||||
| 	err := db.DbMap.Db.Close() | ||||
| 	db.DbMap = nil | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { | ||||
| 	dsn = GetDSN(dbtype, dsn) | ||||
| 	database, err := sql.Open(dbtype.String(), dsn) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		if err != nil { | ||||
| 			database.Close() | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	dbmap, err := GetDbMap(database, dbtype) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &DbStore{dbmap}, nil | ||||
| } | ||||
							
								
								
									
										36
									
								
								internal/store/db/sessions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/store/db/sessions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/aclindsa/moneygo/internal/models" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func (tx *Tx) InsertSession(session *models.Session) error { | ||||
| 	return tx.Insert(session) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) GetSession(secret string) (*models.Session, error) { | ||||
| 	var s models.Session | ||||
|  | ||||
| 	err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if s.Expires.Before(time.Now()) { | ||||
| 		tx.Delete(&s) | ||||
| 		return nil, fmt.Errorf("Session has expired") | ||||
| 	} | ||||
| 	return &s, nil | ||||
| } | ||||
|  | ||||
| func (tx *Tx) SessionExists(secret string) (bool, error) { | ||||
| 	existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret) | ||||
| 	return existing != 0, err | ||||
| } | ||||
|  | ||||
| func (tx *Tx) DeleteSession(session *models.Session) error { | ||||
| 	_, err := tx.Delete(session) | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										57
									
								
								internal/store/db/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								internal/store/db/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| package db | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"github.com/aclindsa/gorp" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Tx struct { | ||||
| 	Dialect gorp.Dialect | ||||
| 	Tx      *gorp.Transaction | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Rebind(query string) string { | ||||
| 	chunks := strings.Split(query, "?") | ||||
| 	str := chunks[0] | ||||
| 	for i := 1; i < len(chunks); i++ { | ||||
| 		str += tx.Dialect.BindVar(i-1) + chunks[i] | ||||
| 	} | ||||
| 	return str | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { | ||||
| 	return tx.Tx.Select(i, tx.Rebind(query), args...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	return tx.Tx.Exec(tx.Rebind(query), args...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { | ||||
| 	return tx.Tx.SelectInt(tx.Rebind(query), args...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { | ||||
| 	return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Insert(list ...interface{}) error { | ||||
| 	return tx.Tx.Insert(list...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Update(list ...interface{}) (int64, error) { | ||||
| 	return tx.Tx.Update(list...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Delete(list ...interface{}) (int64, error) { | ||||
| 	return tx.Tx.Delete(list...) | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Commit() error { | ||||
| 	return tx.Tx.Commit() | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Rollback() error { | ||||
| 	return tx.Tx.Rollback() | ||||
| } | ||||
		Reference in New Issue
	
	Block a user