Skip to content

Commit

Permalink
Merge pull request #3 from shigetaichi/feature/column-reorder
Browse files Browse the repository at this point in the history
feat: Reorder columns
  • Loading branch information
shigetaichi authored Nov 16, 2022
2 parents c428737 + fdda667 commit d13fd86
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 43 deletions.
40 changes: 20 additions & 20 deletions csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,67 +114,67 @@ func getCSVReader(in io.Reader) CSVReader {
// Marshal functions

// MarshalFile saves the interface as CSV in the file.
func MarshalFile(in interface{}, file *os.File, removeFieldsIndexes []int) (err error) {
return Marshal(in, file, removeFieldsIndexes)
func MarshalFile(in interface{}, file *os.File, removeFieldsIndexes []int, colIndex []int) (err error) {
return Marshal(in, file, removeFieldsIndexes, colIndex)
}

// MarshalString returns the CSV string from the interface.
func MarshalString(in interface{}, removeFieldsIndexes []int) (out string, err error) {
func MarshalString(in interface{}, removeFieldsIndexes []int, colIndex []int) (out string, err error) {
bufferString := bytes.NewBufferString(out)
if err := Marshal(in, bufferString, removeFieldsIndexes); err != nil {
if err := Marshal(in, bufferString, removeFieldsIndexes, colIndex); err != nil {
return "", err
}
return bufferString.String(), nil
}

// MarshalStringWithoutHeaders returns the CSV string from the interface.
func MarshalStringWithoutHeaders(in interface{}, removeFieldsIndexes []int) (out string, err error) {
func MarshalStringWithoutHeaders(in interface{}, removeFieldsIndexes []int, colIndex []int) (out string, err error) {
bufferString := bytes.NewBufferString(out)
if err := MarshalWithoutHeaders(in, bufferString, removeFieldsIndexes); err != nil {
if err := MarshalWithoutHeaders(in, bufferString, removeFieldsIndexes, colIndex); err != nil {
return "", err
}
return bufferString.String(), nil
}

// MarshalBytes returns the CSV bytes from the interface.
func MarshalBytes(in interface{}, removeFieldsIndexes []int) (out []byte, err error) {
func MarshalBytes(in interface{}, removeFieldsIndexes []int, colIndex []int) (out []byte, err error) {
bufferString := bytes.NewBuffer(out)
if err := Marshal(in, bufferString, removeFieldsIndexes); err != nil {
if err := Marshal(in, bufferString, removeFieldsIndexes, colIndex); err != nil {
return nil, err
}
return bufferString.Bytes(), nil
}

// Marshal returns the CSV in writer from the interface.
func Marshal(in interface{}, out io.Writer, removeFieldsIndexes []int) (err error) {
func Marshal(in interface{}, out io.Writer, removeFieldsIndexes []int, colIndex []int) (err error) {
writer := getCSVWriter(out)
return writeTo(writer, in, false, removeFieldsIndexes)
return writeTo(writer, in, false, removeFieldsIndexes, colIndex)
}

// MarshalWithoutHeaders returns the CSV in writer from the interface.
func MarshalWithoutHeaders(in interface{}, out io.Writer, removeFieldsIndexes []int) (err error) {
func MarshalWithoutHeaders(in interface{}, out io.Writer, removeFieldsIndexes []int, colIndex []int) (err error) {
writer := getCSVWriter(out)
return writeTo(writer, in, true, removeFieldsIndexes)
return writeTo(writer, in, true, removeFieldsIndexes, colIndex)
}

// MarshalChan returns the CSV read from the channel.
func MarshalChan(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int) error {
return writeFromChan(out, c, false, removeFieldsIndexes)
func MarshalChan(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) error {
return writeFromChan(out, c, false, removeFieldsIndexes, colIndex)
}

// MarshalChanWithoutHeaders returns the CSV read from the channel.
func MarshalChanWithoutHeaders(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int) error {
return writeFromChan(out, c, true, removeFieldsIndexes)
func MarshalChanWithoutHeaders(c <-chan interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) error {
return writeFromChan(out, c, true, removeFieldsIndexes, colIndex)
}

// MarshalCSV returns the CSV in writer from the interface.
func MarshalCSV(in interface{}, out CSVWriter, removeFieldsIndexes []int) (err error) {
return writeTo(out, in, false, removeFieldsIndexes)
func MarshalCSV(in interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) (err error) {
return writeTo(out, in, false, removeFieldsIndexes, colIndex)
}

// MarshalCSVWithoutHeaders returns the CSV in writer from the interface.
func MarshalCSVWithoutHeaders(in interface{}, out CSVWriter, removeFieldsIndexes []int) (err error) {
return writeTo(out, in, true, removeFieldsIndexes)
func MarshalCSVWithoutHeaders(in interface{}, out CSVWriter, removeFieldsIndexes []int, colIndex []int) (err error) {
return writeTo(out, in, true, removeFieldsIndexes, colIndex)
}

// --------------------------------------------------------------------------
Expand Down
33 changes: 31 additions & 2 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package gocsv
import (
"errors"
"fmt"
"golang.org/x/exp/slices"
"io"
"reflect"
"sort"
)

var (
Expand All @@ -19,7 +21,7 @@ func newEncoder(out io.Writer) *encoder {
return &encoder{out}
}

func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, removeFieldsIndexes []int) error {
func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, removeFieldsIndexes []int, colIndex []int) error {
// Get the first value. It wil determine the header structure.
firstValue, ok := <-c
if !ok {
Expand Down Expand Up @@ -49,6 +51,7 @@ func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, rem
return err
}
csvHeadersLabels[j] = inInnerFieldValue
csvHeadersLabels = reorderColumns(csvHeadersLabels, colIndex)
}
if err := writer.Write(csvHeadersLabels); err != nil {
return err
Expand All @@ -71,7 +74,8 @@ func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool, rem
return writer.Error()
}

func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsIndexes []int) error {
func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsIndexes []int, colIndex []int) error {
colIndex = changeToSequence(colIndex)
inValue, inType := getConcreteReflectValueAndType(in) // Get the concrete type (not pointer) (Slice<?> or Array<?>)
if err := ensureInType(inType); err != nil {
return err
Expand All @@ -88,6 +92,7 @@ func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsInd
for i, fieldInfo := range inInnerStructInfo.Fields { // Used to write the header (first line) in CSV
csvHeadersLabels[i] = fieldInfo.getFirstKey()
}
csvHeadersLabels = reorderColumns(csvHeadersLabels, colIndex)
if !omitHeaders {
if err := writer.Write(csvHeadersLabels); err != nil {
return err
Expand All @@ -103,6 +108,7 @@ func writeTo(writer CSVWriter, in interface{}, omitHeaders bool, removeFieldsInd
}
csvHeadersLabels[j] = inInnerFieldValue
}
csvHeadersLabels = reorderColumns(csvHeadersLabels, colIndex)
if err := writer.Write(csvHeadersLabels); err != nil {
return err
}
Expand Down Expand Up @@ -194,3 +200,26 @@ func getFilteredFields(fields []fieldInfo, removeFieldsIndexes []int) []fieldInf
}
return newFields
}

/*
Make colIndex consist of sequential numbers starting from 0.
Ex. [1,2,5,8,0] -> [1,2,3,4,0]
*/
func changeToSequence(colIndex []int) []int {
copiedColIndex := make([]int, len(colIndex))
copy(copiedColIndex, colIndex)
sort.Ints(copiedColIndex)

for i, v := range colIndex {
colIndex[i] = slices.Index(copiedColIndex, v)
}
return colIndex
}

func reorderColumns(row []string, colIndex []int) []string {
newLine := make([]string, len(row))
for from, to := range colIndex {
newLine[to] = row[from]
}
return newLine
}
77 changes: 58 additions & 19 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"io/ioutil"
"math"
"reflect"
"strconv"
"strings"
"testing"
Expand All @@ -24,6 +25,14 @@ func assertLine(t *testing.T, expected, actual []string) {
}
}

func generateFakeColIndex(len int) []int {
colIndex := make([]int, len)
for i := range colIndex {
colIndex[i] = i
}
return colIndex
}

func Test_writeTo(t *testing.T) {
b := bytes.Buffer{}
e := &encoder{out: &b}
Expand All @@ -33,7 +42,9 @@ func Test_writeTo(t *testing.T) {
{Foo: "f", Bar: 1, Baz: "baz", Frop: 0.1, Blah: &blah, SPtr: &sptr},
{Foo: "e", Bar: 3, Baz: "b", Frop: 6.0 / 13, Blah: nil, SPtr: nil},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {

colIndex := generateFakeColIndex(reflect.TypeOf(Sample{}).NumField())
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand All @@ -56,7 +67,10 @@ func Test_writeTo_Time(t *testing.T) {
s := []DateTime{
{Foo: d},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}); err != nil {

colIndex := generateFakeColIndex(reflect.TypeOf(DateTime{}).NumField())

if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -87,7 +101,8 @@ func Test_writeTo_NoHeaders(t *testing.T) {
{Foo: "f", Bar: 1, Baz: "baz", Frop: 0.1, Blah: &blah, SPtr: &sptr},
{Foo: "e", Bar: 3, Baz: "b", Frop: 6.0 / 13, Blah: nil, SPtr: nil},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}); err != nil {
colIndex := generateFakeColIndex(reflect.TypeOf(Sample{}).NumField())
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, true, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand All @@ -109,7 +124,8 @@ func Test_writeTo_multipleTags(t *testing.T) {
{Foo: "abc", Bar: 123},
{Foo: "def", Bar: 234},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {
colIndex := generateFakeColIndex(reflect.TypeOf(MultiTagSample{}).NumField())
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -146,7 +162,8 @@ func Test_writeTo_slice(t *testing.T) {
},
}

if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {
colIndex := generateFakeColIndex(reflect.TypeOf(TestType{}).NumField())
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -181,7 +198,9 @@ func Test_writeTo_slice_structs(t *testing.T) {
},
},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {

colIndex := generateFakeColIndex(11)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -210,7 +229,8 @@ func Test_writeTo_embed(t *testing.T) {
Grault: math.Pi,
},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {
colIndex := generateFakeColIndex(10)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -239,7 +259,9 @@ func Test_writeTo_embedptr(t *testing.T) {
Grault: math.Pi,
},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {

colIndex := generateFakeColIndex(10)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand All @@ -260,7 +282,9 @@ func Test_writeTo_embedptr_nil(t *testing.T) {
s := []EmbedPtrSample{
{},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {

colIndex := generateFakeColIndex(10)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand All @@ -283,7 +307,9 @@ func Test_writeTo_embedmarshal(t *testing.T) {
Foo: &MarshalSample{Dummy: "bar"},
},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {

colIndex := generateFakeColIndex(1)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -314,8 +340,10 @@ func Test_writeTo_embedmarshalCSV(t *testing.T) {
},
}

colIndex := generateFakeColIndex(2)

// Next, attempt to write our test data to a CSV format
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}); err != nil {
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), s, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -358,7 +386,9 @@ func Test_writeTo_complex_embed(t *testing.T) {
Corge: "hhh",
},
}
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, false, []int{}); err != nil {

colIndex := generateFakeColIndex(11)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, false, []int{}, colIndex); err != nil {
t.Fatal(err)
}
lines, err := csv.NewReader(&b).ReadAll()
Expand Down Expand Up @@ -400,7 +430,8 @@ func Test_writeTo_complex_inner_struct_embed(t *testing.T) {
},
}

if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, true, []int{}); err != nil {
colIndex := generateFakeColIndex(2)
if err := writeTo(NewSafeCSVWriter(csv.NewWriter(e.out)), sfs, true, []int{}, colIndex); err != nil {
t.Fatal(err)
}
lines, err := csv.NewReader(&b).ReadAll()
Expand All @@ -423,7 +454,9 @@ func Test_writeToChan(t *testing.T) {
}
close(c)
}()
if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}); err != nil {

colIndex := generateFakeColIndex(7)
if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}, colIndex); err != nil {
t.Fatal(err)
}
lines, err := csv.NewReader(&b).ReadAll()
Expand All @@ -448,7 +481,8 @@ func Test_MarshalChan_ClosedChannel(t *testing.T) {
c := make(chan interface{})
close(c)

if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}); !errors.Is(err, ErrChannelIsClosed) {
colIndex := generateFakeColIndex(7)
if err := MarshalChan(c, NewSafeCSVWriter(csv.NewWriter(e.out)), []int{}, colIndex); !errors.Is(err, ErrChannelIsClosed) {
t.Fatal(err)
}
}
Expand All @@ -468,7 +502,8 @@ func TestRenamedTypesMarshal(t *testing.T) {
// Switch back to default for tests executed after this
defer SetCSVWriter(DefaultCSVWriter)

csvContent, err := MarshalString(&samples, []int{})
colIndex := generateFakeColIndex(reflect.TypeOf(RenamedSample{}).NumField())
csvContent, err := MarshalString(&samples, []int{}, colIndex)
if err != nil {
t.Fatal(err)
}
Expand All @@ -480,7 +515,8 @@ func TestRenamedTypesMarshal(t *testing.T) {
samples = []RenamedSample{
{RenamedFloatUnmarshaler: 4.2, RenamedFloatDefault: 1.5},
}
_, err = MarshalString(&samples, []int{})

_, err = MarshalString(&samples, []int{}, colIndex)
if _, ok := err.(MarshalError); !ok {
t.Fatalf("Expected UnmarshalError, got %v", err)
}
Expand All @@ -499,7 +535,8 @@ func TestCustomTagSeparatorMarshal(t *testing.T) {
TagSeparator = ","
}()

csvContent, err := MarshalString(&samples, []int{})
colIndex := generateFakeColIndex(reflect.TypeOf(RenamedSample{}).NumField())
csvContent, err := MarshalString(&samples, []int{}, colIndex)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -528,7 +565,9 @@ func (e MarshalError) Error() string {
func Benchmark_MarshalCSVWithoutHeaders(b *testing.B) {
dst := NewSafeCSVWriter(csv.NewWriter(ioutil.Discard))
for n := 0; n < b.N; n++ {
err := MarshalCSVWithoutHeaders([]Sample{{}}, dst, []int{})

colIndex := generateFakeColIndex(0)
err := MarshalCSVWithoutHeaders([]Sample{{}}, dst, []int{}, colIndex)
if err != nil {
b.Fatal(err)
}
Expand Down
Loading

0 comments on commit d13fd86

Please sign in to comment.