diff --git a/csv.go b/csv.go index 17cd5c5..4d9bbe3 100644 --- a/csv.go +++ b/csv.go @@ -114,67 +114,67 @@ func getCSVReader(in io.Reader) CSVReader { // Marshal functions // MarshalFile saves the interface as CSV in the file. -func MarshalFile(in interface{}, file *os.File, removeFieldsIndexes []int) (err error) { - return Marshal(in, file, removeFieldsIndexes) +func MarshalFile(in interface{}, file *os.File, removeFieldsIndexes []int, colIndex []int) (err error) { + return Marshal(in, file, removeFieldsIndexes, colIndex) } // MarshalString returns the CSV string from the interface. -func MarshalString(in interface{}, removeFieldsIndexes []int) (out string, err error) { +func MarshalString(in interface{}, removeFieldsIndexes []int, colIndex []int) (out string, err error) { bufferString := bytes.NewBufferString(out) - if err := Marshal(in, bufferString, removeFieldsIndexes); err != nil { + if err := Marshal(in, bufferString, removeFieldsIndexes, colIndex); err != nil { return "", err } return bufferString.String(), nil } // MarshalStringWithoutHeaders returns the CSV string from the interface. -func MarshalStringWithoutHeaders(in interface{}, removeFieldsIndexes []int) (out string, err error) { +func MarshalStringWithoutHeaders(in interface{}, removeFieldsIndexes []int, colIndex []int) (out string, err error) { bufferString := bytes.NewBufferString(out) - if err := MarshalWithoutHeaders(in, bufferString, removeFieldsIndexes); err != nil { + if err := MarshalWithoutHeaders(in, bufferString, removeFieldsIndexes, colIndex); err != nil { return "", err } return bufferString.String(), nil } // MarshalBytes returns the CSV bytes from the interface. -func MarshalBytes(in interface{}, removeFieldsIndexes []int) (out []byte, err error) { +func MarshalBytes(in interface{}, removeFieldsIndexes []int, colIndex []int) (out []byte, err error) { bufferString := bytes.NewBuffer(out) - if err := Marshal(in, bufferString, removeFieldsIndexes); err != nil { + if err := Marshal(in, bufferString, removeFieldsIndexes, colIndex); err != nil { return nil, err } return bufferString.Bytes(), nil } // Marshal returns the CSV in writer from the interface. -func Marshal(in interface{}, out io.Writer, removeFieldsIndexes []int) (err error) { +func Marshal(in interface{}, out io.Writer, removeFieldsIndexes []int, colIndex []int) (err error) { writer := getCSVWriter(out) - return writeTo(writer, in, false, removeFieldsIndexes) + return writeTo(writer, in, false, removeFieldsIndexes, colIndex) } // MarshalWithoutHeaders returns the CSV in writer from the interface. -func MarshalWithoutHeaders(in interface{}, out io.Writer, removeFieldsIndexes []int) (err error) { +func MarshalWithoutHeaders(in interface{}, out io.Writer, removeFieldsIndexes []int, colIndex []int) (err error) { writer := getCSVWriter(out) - return writeTo(writer, in, true, removeFieldsIndexes) + return writeTo(writer, in, true, removeFieldsIndexes, colIndex) } // MarshalChan returns the CSV read from the channel. -func MarshalChan(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int) error { - return writeFromChan(out, c, false, removeFieldsIndexes) +func MarshalChan(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) error { + return writeFromChan(out, c, false, removeFieldsIndexes, colIndex) } // MarshalChanWithoutHeaders returns the CSV read from the channel. -func MarshalChanWithoutHeaders(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int) error { - return writeFromChan(out, c, true, removeFieldsIndexes) +func MarshalChanWithoutHeaders(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) error { + return writeFromChan(out, c, true, removeFieldsIndexes, colIndex) } // MarshalCSV returns the CSV in writer from the interface. -func MarshalCSV(in interface{}, out CSVWriter, removeFieldsIndexes []int) (err error) { - return writeTo(out, in, false, removeFieldsIndexes) +func MarshalCSV(in interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) (err error) { + return writeTo(out, in, false, removeFieldsIndexes, colIndex) } // MarshalCSVWithoutHeaders returns the CSV in writer from the interface. -func MarshalCSVWithoutHeaders(in interface{}, out CSVWriter, removeFieldsIndexes []int) (err error) { - return writeTo(out, in, true, removeFieldsIndexes) +func MarshalCSVWithoutHeaders(in interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) (err error) { + return writeTo(out, in, true, removeFieldsIndexes, colIndex) } // -------------------------------------------------------------------------- diff --git a/encode.go b/encode.go index 97d04e5..8b3478f 100644 --- a/encode.go +++ b/encode.go @@ -3,8 +3,10 @@ package gocsv import ( "errors" "fmt" + "golang.org/x/exp/slices" "io" "reflect" + "sort" ) var ( @@ -19,7 +21,7 @@ func newEncoder(out io.Writer) *encoder { return &encoder{out} } -func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, removeFieldsIndexes []int) error { +func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, removeFieldsIndexes []int, colIndex []int) error { // Get the first value. It wil determine the header structure. firstValue, ok := <-c if !ok { @@ -49,6 +51,7 @@ func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, rem return err } csvHeadersLabels[j] = inInnerFieldValue + csvHeadersLabels = reorderColumns(csvHeadersLabels, colIndex) } if err := writer.Write(csvHeadersLabels); err != nil { return err @@ -71,7 +74,8 @@ func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, rem return writer.Error() } -func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsIndexes []int) error { +func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsIndexes []int, colIndex []int) error { + colIndex = changeToSequence(colIndex) inValue, inType := getConcreteReflectValueAndType(in) // Get the concrete type (not pointer) (Slice or Array) if err := ensureInType(inType); err != nil { return err @@ -88,6 +92,7 @@ func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsInd for i, fieldInfo := range inInnerStructInfo.Fields { // Used to write the header (first line) in CSV csvHeadersLabels[i] = fieldInfo.getFirstKey() } + csvHeadersLabels = reorderColumns(csvHeadersLabels, colIndex) if !omitHeaders { if err := writer.Write(csvHeadersLabels); err != nil { return err @@ -103,6 +108,7 @@ func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsInd } csvHeadersLabels[j] = inInnerFieldValue } + csvHeadersLabels = reorderColumns(csvHeadersLabels, colIndex) if err := writer.Write(csvHeadersLabels); err != nil { return err } @@ -194,3 +200,26 @@ func getFilteredFields(fields []fieldInfo, removeFieldsIndexes []int) []fieldInf } return newFields } + +/* +Make colIndex consist of sequential numbers starting from 0. +Ex. [1,2,5,8,0] -> [1,2,3,4,0] +*/ +func changeToSequence(colIndex []int) []int { + copiedColIndex := make([]int, len(colIndex)) + copy(copiedColIndex, colIndex) + sort.Ints(copiedColIndex) + + for i, v := range colIndex { + colIndex[i] = slices.Index(copiedColIndex, v) + } + return colIndex +} + +func reorderColumns(row []string, colIndex []int) []string { + newLine := make([]string, len(row)) + for from, to := range colIndex { + newLine[to] = row[from] + } + return newLine +} diff --git a/encode_test.go b/encode_test.go index dc05d10..58a04c5 100644 --- a/encode_test.go +++ b/encode_test.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "math" + "reflect" "strconv" "strings" "testing" @@ -24,6 +25,14 @@ func assertLine(t *testing.T, expected, actual []string) { } } +func generateFakeColIndex(len int) []int { + colIndex := make([]int, len) + for i := range colIndex { + colIndex[i] = i + } + return colIndex +} + func Test_writeTo(t *testing.T) { b := bytes.Buffer{} e := &encoder{out: &b} @@ -33,7 +42,9 @@ func Test_writeTo(t *testing.T) { {Foo: "f", Bar: 1, Baz: "baz", Frop: 0.1, Blah: &blah, SPtr: &sptr}, {Foo: "e", Bar: 3, Baz: "b", Frop: 6.0 / 13, Blah: nil, SPtr: nil}, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + + colIndex := generateFakeColIndex(reflect.TypeOf(Sample{}).NumField()) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -56,7 +67,10 @@ func Test_writeTo_Time(t *testing.T) { s := []DateTime{ {Foo: d}, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}); err != nil { + + colIndex := generateFakeColIndex(reflect.TypeOf(DateTime{}).NumField()) + + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -87,7 +101,8 @@ func Test_writeTo_NoHeaders(t *testing.T) { {Foo: "f", Bar: 1, Baz: "baz", Frop: 0.1, Blah: &blah, SPtr: &sptr}, {Foo: "e", Bar: 3, Baz: "b", Frop: 6.0 / 13, Blah: nil, SPtr: nil}, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}); err != nil { + colIndex := generateFakeColIndex(reflect.TypeOf(Sample{}).NumField()) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -109,7 +124,8 @@ func Test_writeTo_multipleTags(t *testing.T) { {Foo: "abc", Bar: 123}, {Foo: "def", Bar: 234}, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + colIndex := generateFakeColIndex(reflect.TypeOf(MultiTagSample{}).NumField()) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -146,7 +162,8 @@ func Test_writeTo_slice(t *testing.T) { }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + colIndex := generateFakeColIndex(reflect.TypeOf(TestType{}).NumField()) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -181,7 +198,9 @@ func Test_writeTo_slice_structs(t *testing.T) { }, }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + + colIndex := generateFakeColIndex(11) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -210,7 +229,8 @@ func Test_writeTo_embed(t *testing.T) { Grault: math.Pi, }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + colIndex := generateFakeColIndex(10) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -239,7 +259,9 @@ func Test_writeTo_embedptr(t *testing.T) { Grault: math.Pi, }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + + colIndex := generateFakeColIndex(10) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -260,7 +282,9 @@ func Test_writeTo_embedptr_nil(t *testing.T) { s := []EmbedPtrSample{ {}, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + + colIndex := generateFakeColIndex(10) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -283,7 +307,9 @@ func Test_writeTo_embedmarshal(t *testing.T) { Foo: &MarshalSample{Dummy: "bar"}, }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + + colIndex := generateFakeColIndex(1) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -314,8 +340,10 @@ func Test_writeTo_embedmarshalCSV(t *testing.T) { }, } + colIndex := generateFakeColIndex(2) + // Next, attempt to write our test data to a CSV format - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil { + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil { t.Fatal(err) } @@ -358,7 +386,9 @@ func Test_writeTo_complex_embed(t *testing.T) { Corge: "hhh", }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, false, []int{}); err != nil { + + colIndex := generateFakeColIndex(11) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, false, []int{}, colIndex); err != nil { t.Fatal(err) } lines, err := csv.NewReader(&b).ReadAll() @@ -400,7 +430,8 @@ func Test_writeTo_complex_inner_struct_embed(t *testing.T) { }, } - if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, true, []int{}); err != nil { + colIndex := generateFakeColIndex(2) + if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, true, []int{}, colIndex); err != nil { t.Fatal(err) } lines, err := csv.NewReader(&b).ReadAll() @@ -423,7 +454,9 @@ func Test_writeToChan(t *testing.T) { } close(c) }() - if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}); err != nil { + + colIndex := generateFakeColIndex(7) + if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}, colIndex); err != nil { t.Fatal(err) } lines, err := csv.NewReader(&b).ReadAll() @@ -448,7 +481,8 @@ func Test_MarshalChan_ClosedChannel(t *testing.T) { c := make(chan interface{}) close(c) - if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}); !errors.Is(err, ErrChannelIsClosed) { + colIndex := generateFakeColIndex(7) + if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}, colIndex); !errors.Is(err, ErrChannelIsClosed) { t.Fatal(err) } } @@ -468,7 +502,8 @@ func TestRenamedTypesMarshal(t *testing.T) { // Switch back to default for tests executed after this defer SetCSVWriter(DefaultCSVWriter) - csvContent, err := MarshalString(&samples, []int{}) + colIndex := generateFakeColIndex(reflect.TypeOf(RenamedSample{}).NumField()) + csvContent, err := MarshalString(&samples, []int{}, colIndex) if err != nil { t.Fatal(err) } @@ -480,7 +515,8 @@ func TestRenamedTypesMarshal(t *testing.T) { samples = []RenamedSample{ {RenamedFloatUnmarshaler: 4.2, RenamedFloatDefault: 1.5}, } - _, err = MarshalString(&samples, []int{}) + + _, err = MarshalString(&samples, []int{}, colIndex) if _, ok := err.(MarshalError); !ok { t.Fatalf("Expected UnmarshalError, got %v", err) } @@ -499,7 +535,8 @@ func TestCustomTagSeparatorMarshal(t *testing.T) { TagSeparator = "," }() - csvContent, err := MarshalString(&samples, []int{}) + colIndex := generateFakeColIndex(reflect.TypeOf(RenamedSample{}).NumField()) + csvContent, err := MarshalString(&samples, []int{}, colIndex) if err != nil { t.Fatal(err) } @@ -528,7 +565,9 @@ func (e MarshalError) Error() string { func Benchmark_MarshalCSVWithoutHeaders(b *testing.B) { dst := NewSafeCSVWriter(csv.NewWriter(ioutil.Discard)) for n := 0; n < b.N; n++ { - err := MarshalCSVWithoutHeaders([]Sample{{}}, dst, []int{}) + + colIndex := generateFakeColIndex(0) + err := MarshalCSVWithoutHeaders([]Sample{{}}, dst, []int{}, colIndex) if err != nil { b.Fatal(err) } diff --git a/go.mod b/go.mod index 912e92f..c7d0f01 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/shigetaichi/gocsv -go 1.13 +go 1.18 + +require golang.org/x/exp v0.0.0-20221028150844-83b7d23a625f diff --git a/unmarshaller_test.go b/unmarshaller_test.go index d821230..a5b6383 100644 --- a/unmarshaller_test.go +++ b/unmarshaller_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/csv" "io" + "reflect" "strings" "testing" ) @@ -69,7 +70,8 @@ func TestUnmarshalListOfStructsAfterMarshal(t *testing.T) { innerWriter := csv.NewWriter(buffer) innerWriter.Comma = '|' csvWriter := NewSafeCSVWriter(innerWriter) - if err := MarshalCSV(inData, csvWriter, []int{}); err != nil { + colIndex := generateFakeColIndex(reflect.TypeOf(Option{}).NumField()) + if err := MarshalCSV(inData, csvWriter, []int{}, colIndex); err != nil { t.Fatalf("Error marshalling data to CSV: %#v", err) }