diff --git a/bank_test.go b/bank_test.go index bb9da4f..f9d7b9c 100644 --- a/bank_test.go +++ b/bank_test.go @@ -283,4 +283,5 @@ func TestUnmarshalBankStatementResponse(t *testing.T) { } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } diff --git a/creditcard_test.go b/creditcard_test.go index e384488..d487dd6 100644 --- a/creditcard_test.go +++ b/creditcard_test.go @@ -161,4 +161,5 @@ NEWFILEUID:NONE } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } diff --git a/invstmt.go b/invstmt.go index 3adeee6..9a36e45 100644 --- a/invstmt.go +++ b/invstmt.go @@ -465,6 +465,7 @@ type InvBankTransaction struct { // security-related transactions themselves. It must be unmarshalled manually // due to the structure (don't know what kind of InvTransaction is coming next) type InvTranList struct { + XMLName xml.Name `xml:"INVTRANLIST"` DtStart Date DtEnd Date // This is the value that should be sent as in the next InvStatementRequest to ensure that no transactions are missed InvTransactions []InvTransaction @@ -630,6 +631,119 @@ func (l *InvTranList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error } } +// MarshalXML handles marshalling an InvTranList element to an SGML/XML string +func (l *InvTranList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + invTranListElement := xml.StartElement{Name: xml.Name{Local: "INVTRANLIST"}} + if err := e.EncodeToken(invTranListElement); err != nil { + return err + } + err := e.EncodeElement(&l.DtStart, xml.StartElement{Name: xml.Name{Local: "DTSTART"}}) + if err != nil { + return err + } + err = e.EncodeElement(&l.DtEnd, xml.StartElement{Name: xml.Name{Local: "DTEND"}}) + if err != nil { + return err + } + for _, t := range l.InvTransactions { + start := xml.StartElement{Name: xml.Name{Local: t.TransactionType()}} + switch tran := t.(type) { + case BuyDebt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyMF: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyOpt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyOther: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyStock: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case ClosureOpt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Income: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case InvExpense: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case JrnlFund: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case JrnlSec: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case MarginInterest: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Reinvest: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case RetOfCap: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellDebt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellMF: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellOpt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellOther: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellStock: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Split: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Transfer: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + default: + return errors.New("Invalid INVTRANLIST child type: " + tran.TransactionType()) + } + } + for _, tran := range l.BankTransactions { + err = e.EncodeElement(&tran, xml.StartElement{Name: xml.Name{Local: "INVBANKTRAN"}}) + if err != nil { + return err + } + } + if err := e.EncodeToken(invTranListElement.End()); err != nil { + return err + } + return nil +} + // InvPosition contains generic position information included in each of the // other *Position types type InvPosition struct { @@ -770,6 +884,45 @@ func (p *PositionList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro } } +// MarshalXML handles marshalling a PositionList to an XML string +func (p *PositionList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + invPosListElement := xml.StartElement{Name: xml.Name{Local: "INVPOSLIST"}} + if err := e.EncodeToken(invPosListElement); err != nil { + return err + } + for _, position := range *p { + start := xml.StartElement{Name: xml.Name{Local: position.PositionType()}} + switch pos := position.(type) { + case DebtPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case MFPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case OptPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case OtherPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case StockPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + default: + return errors.New("Invalid INVPOSLIST child type: " + pos.PositionType()) + } + } + if err := e.EncodeToken(invPosListElement.End()); err != nil { + return err + } + return nil +} + // InvBalance contains three (or optionally four) specified balances as well as // a free-form list of generic balance information which may be provided by an // FI. @@ -1036,6 +1189,69 @@ func (o *OOList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { } } +// MarshalXML handles marshalling an OOList to an XML string +func (o *OOList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + ooListElement := xml.StartElement{Name: xml.Name{Local: "INVOOLIST"}} + if err := e.EncodeToken(ooListElement); err != nil { + return err + } + for _, openorder := range *o { + start := xml.StartElement{Name: xml.Name{Local: openorder.OrderType()}} + switch oo := openorder.(type) { + case OOBuyDebt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyMF: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyOpt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyOther: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyStock: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellDebt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellMF: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellOpt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellOther: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellStock: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSwitchMF: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + default: + return errors.New("Invalid OOLIST child type: " + oo.OrderType()) + } + } + if err := e.EncodeToken(ooListElement.End()); err != nil { + return err + } + return nil +} + // ContribSecurity identifies current contribution allocation for a security in // a 401(k) account type ContribSecurity struct { diff --git a/invstmt_test.go b/invstmt_test.go index afe37b9..e1a1e14 100644 --- a/invstmt_test.go +++ b/invstmt_test.go @@ -602,6 +602,7 @@ func TestUnmarshalInvStatementResponse(t *testing.T) { } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } func TestUnmarshalInvStatementResponse102(t *testing.T) { @@ -957,6 +958,7 @@ NEWFILEUID: NONE } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } func TestUnmarshalInvTranList(t *testing.T) { diff --git a/profile.go b/profile.go index c28cc8a..8bacc60 100644 --- a/profile.go +++ b/profile.go @@ -3,6 +3,7 @@ package ofxgo import ( "errors" "github.com/aclindsa/xml" + "strings" ) // ProfileRequest represents a request for a server to provide a profile of its @@ -126,6 +127,35 @@ func (msl *MessageSetList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) } } +// MarshalXML handles marshalling a MessageSetList element to an XML string +func (msl *MessageSetList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + messageSetListElement := xml.StartElement{Name: xml.Name{Local: "MSGSETLIST"}} + if err := e.EncodeToken(messageSetListElement); err != nil { + return err + } + for _, messageset := range *msl { + if !strings.HasSuffix(messageset.Name, "V1") { + return errors.New("Expected MessageSet.Name to end with \"V1\"") + } + messageSetName := strings.TrimSuffix(messageset.Name, "V1") + messageSetElement := xml.StartElement{Name: xml.Name{Local: messageSetName}} + if err := e.EncodeToken(messageSetElement); err != nil { + return err + } + start := xml.StartElement{Name: xml.Name{Local: messageset.Name}} + if err := e.EncodeElement(&messageset, start); err != nil { + return err + } + if err := e.EncodeToken(messageSetElement.End()); err != nil { + return err + } + } + if err := e.EncodeToken(messageSetListElement.End()); err != nil { + return err + } + return nil +} + // ProfileResponse contains a requested profile of the server's capabilities // (which message sets and versions it supports, how to access them, which // languages and which types of synchronization they support, etc.). Note that diff --git a/profile_test.go b/profile_test.go index 71db474..b15d0c3 100644 --- a/profile_test.go +++ b/profile_test.go @@ -325,4 +325,5 @@ NEWFILEUID:NONE } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } diff --git a/response.go b/response.go index 1d177be..9ab4db7 100644 --- a/response.go +++ b/response.go @@ -369,3 +369,70 @@ func ParseResponse(reader io.Reader) (*Response, error) { } } } + +// Marshal this Response into its SGML/XML representation held in a bytes.Buffer +// +// If error is non-nil, this bytes.Buffer is ready to be sent to an OFX client +func (or *Response) Marshal() (*bytes.Buffer, error) { + var b bytes.Buffer + + // Write the header appropriate to our version + writeHeader(&b, or.Version) + + encoder := xml.NewEncoder(&b) + encoder.Indent("", " ") + + ofxElement := xml.StartElement{Name: xml.Name{Local: "OFX"}} + + if err := encoder.EncodeToken(ofxElement); err != nil { + return nil, err + } + + if ok, err := or.Signon.Valid(or.Version); !ok { + return nil, err + } + signonMsgSet := xml.StartElement{Name: xml.Name{Local: SignonRs.String()}} + if err := encoder.EncodeToken(signonMsgSet); err != nil { + return nil, err + } + if err := encoder.Encode(&or.Signon); err != nil { + return nil, err + } + if err := encoder.EncodeToken(signonMsgSet.End()); err != nil { + return nil, err + } + + messageSets := []struct { + Messages []Message + Type messageType + }{ + {or.Signup, SignupRs}, + {or.Bank, BankRs}, + {or.CreditCard, CreditCardRs}, + {or.Loan, LoanRs}, + {or.InvStmt, InvStmtRs}, + {or.InterXfer, InterXferRs}, + {or.WireXfer, WireXferRs}, + {or.Billpay, BillpayRs}, + {or.Email, EmailRs}, + {or.SecList, SecListRs}, + {or.PresDir, PresDirRs}, + {or.PresDlv, PresDlvRs}, + {or.Prof, ProfRs}, + {or.Image, ImageRs}, + } + for _, set := range messageSets { + if err := encodeMessageSet(encoder, set.Messages, set.Type, or.Version); err != nil { + return nil, err + } + } + + if err := encoder.EncodeToken(ofxElement.End()); err != nil { + return nil, err + } + + if err := encoder.Flush(); err != nil { + return nil, err + } + return &b, nil +} diff --git a/response_test.go b/response_test.go index 1018002..c1faa3b 100644 --- a/response_test.go +++ b/response_test.go @@ -136,6 +136,20 @@ func checkResponsesEqual(t *testing.T, expected, actual *ofxgo.Response) { checkEqual(t, "", reflect.ValueOf(expected), reflect.ValueOf(actual)) } +func checkResponseRoundTrip(t *testing.T, response *ofxgo.Response) { + b, err := response.Marshal() + if err != nil { + t.Fatalf("Unexpected error re-marshaling OFX response: %s\n", err) + } + roundtripped, err := ofxgo.ParseResponse(b) + if err != nil { + t.Fatalf("Unexpected error re-parsing OFX response: %s\n", err) + } + checkResponsesEqual(t, response, roundtripped) +} + +// Ensure that these samples both parse without errors, and can be converted +// back and forth without changing. func TestValidSamples(t *testing.T) { fn := func(path string, info os.FileInfo, err error) error { if info.IsDir() { @@ -147,10 +161,11 @@ func TestValidSamples(t *testing.T) { if err != nil { t.Fatalf("Unexpected error opening %s: %s\n", path, err) } - _, err = ofxgo.ParseResponse(file) + response, err := ofxgo.ParseResponse(file) if err != nil { t.Fatalf("Unexpected error parsing OFX response in %s: %s\n", path, err) } + checkResponseRoundTrip(t, response) return nil } filepath.Walk("samples/valid_responses", fn) diff --git a/seclist.go b/seclist.go index 8c60582..c8d3334 100644 --- a/seclist.go +++ b/seclist.go @@ -221,6 +221,7 @@ func (i StockInfo) SecurityType() string { // SecurityList is a container for Security objects containaing information // about securities type SecurityList struct { + XMLName xml.Name `xml:"SECLIST"` Securities []Security } @@ -290,3 +291,42 @@ func (r *SecurityList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro } } } + +// MarshalXML handles marshalling a SecurityList to an SGML/XML string +func (r *SecurityList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + secListElement := xml.StartElement{Name: xml.Name{Local: "SECLIST"}} + if err := e.EncodeToken(secListElement); err != nil { + return err + } + for _, s := range r.Securities { + start := xml.StartElement{Name: xml.Name{Local: s.SecurityType()}} + switch sec := s.(type) { + case DebtInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case MFInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case OptInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case OtherInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case StockInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + default: + return errors.New("Invalid SECLIST child type: " + sec.SecurityType()) + } + } + if err := e.EncodeToken(secListElement.End()); err != nil { + return err + } + return nil +} diff --git a/signup_test.go b/signup_test.go index da2ac68..115ec1f 100644 --- a/signup_test.go +++ b/signup_test.go @@ -155,4 +155,5 @@ func TestUnmarshalAcctInfoResponse(t *testing.T) { } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) }