diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..2289169 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,111 @@ +linters-settings: + funlen: + lines: 110 + statements: 70 + gci: + sections: + - standard + - default + - localmodule + custom-order: true + goconst: + min-len: 2 + min-occurrences: 2 + gocritic: + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + gocyclo: + min-complexity: 15 + cyclop: + skip-tests: true + max-complexity: 15 + godot: + capital: true + goimports: + local-prefixes: github.com/lzambarda/goflat + govet: + settings: + printf: + funcs: + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf + disable: + - fieldalignment + lll: + line-length: 140 + misspell: + locale: UK + tagliatelle: + case: + rules: + json: snake + unparam: + check-exported: true + + wrapcheck: + ignoreSigs: + - .Errorf( + - errors.New( + - errors.Unwrap( + - errors.Join( + - .Wrap( + - .Wrapf( + - .WithMessage( + - .WithMessagef( + - .WithStack( + - status.Error( + + wsl: + allow-cuddle-declarations: true + +issues: + # Excluding configuration per-path, per-linter, per-text and per-source + exclude-rules: + - path: _test\.go + linters: + - bodyclose + - dupl # we usually duplicate code in tests + - dupword + - errcheck + - errchkjson # we mostly dump file diffs, no biggie + - funlen + - gochecknoglobals + - goconst # sometimes it is easier this way + - gocritic + - gosec # security check is not important in tests + - govet + - maintidx + - revive + - unparam + - varnamelen + - wrapcheck + - path: testing + linters: + - errcheck + fix: true + exclude-use-default: false + exclude-dirs: + - model + - tmp + - bin + - scripts + +run: + allow-parallel-runners: true + tests: true + build-tags: + - integration + +linters: + enable-all: true + disable: + - exhaustruct # I want to use zero values... and sometime leave a field uninitialised, because it'll be later. + - depguard # because I don't want to write a dedicated config file. + - nonamedreturns # I don't fully agree with this + - paralleltest # I don't agree with this level of nitpicking diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f3f54db --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Luca Zambarda + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f7667ba --- /dev/null +++ b/README.md @@ -0,0 +1,54 @@ +# goflat + +Generic-friendly flat file marshaller and unmarshaller using the `flat` field tag in structs. + +## Overview + +```go +type Record struct { + FirstName string `flat:"first_name"` + LastName string `flat:"last_name"` + Age int `flat:"age"` + Height float32 `flat:"-"` // ignored +} + +ch := make(chan Record) + +... + +goflat.MarshalSliceToWriter[Record](ctx,ch,csvWriter,options) +``` + +Will result in: + +``` +first_name,last_name,age +John,Doe,30 +Jane,Doe,20 +``` + +## Options + +Both marshal and unmarshal operations support `goflat.Options`, which allow to introduce automatic safety checks, such as duplicated headers, `flat` tag coverage and more. + +## Custom marshal / unmarshal + +Both operations can be customised for each field in a struct by having that value implementing `goflat.Marshal` and/or `goflat.Unmarshal`. + +```go +type Record struct { + Field MyType `flat:"field"` +} + +type MyType struct { + Value int +} + +func (m *MyType) Marshal() (string,error) { + if m.Value %2 == 0 { + return "odd", nil + } + + return "even", nil +} +``` diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..ca2800a --- /dev/null +++ b/doc.go @@ -0,0 +1,2 @@ +// Package goflat contains all the code to marshal and unmarshal tabular files. +package goflat diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..115072c --- /dev/null +++ b/errors.go @@ -0,0 +1,25 @@ +package goflat + +import "errors" + +var ( + // ErrNotAStruct is returned when the value to be worked with is not a struct. + ErrNotAStruct = errors.New("not a struct") + // ErrTaglessField is returned when goflat works in strict mode and a field + // of the input struct has no "flat" tag. + ErrTaglessField = errors.New("tagless field") + // ErrDuplicatedHeader is returned when there is more than one header with + // the same value. Only returned if [Option.ErrorIfDuplicateHeaders] is set + // to true. + ErrDuplicatedHeader = errors.New("duplicated header") + // ErrMissingHeader is returned when a header referenced in a "flat" tag + // does not appear in the input file. Only returned if + // [Option.ErrorIfMissingHeaders] is set to true. + ErrMissingHeader = errors.New("missing header") + // ErrMismatchedFields is returned when the input structs have inconsistent + // fields. In theory this will never be returned. + ErrMismatchedFields = errors.New("mismatched fields") + // ErrUnsupportedType is returned when the unmarshaller encounters an + // unsupported type. + ErrUnsupportedType = errors.New("unsupported type") +) diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c9ad154 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/lzambarda/goflat + +go 1.23.2 + +require ( + github.com/google/go-cmp v0.6.0 + golang.org/x/sync v0.9.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2243e7b --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/marshal.go b/marshal.go new file mode 100644 index 0000000..ee631de --- /dev/null +++ b/marshal.go @@ -0,0 +1,84 @@ +package goflat + +import ( + "context" + "encoding/csv" + "fmt" +) + +// Marshaller can be used to tell goflat to use custom logic to convert a field +// into a string. +type Marshaller interface { + Marshal() (string, error) +} + +// MarshalSliceToWriter marshals a slice of structs to a CSV file. +func MarshalSliceToWriter[T any](ctx context.Context, values []T, writer *csv.Writer, opts Options) error { + ch := make(chan T) //nolint:varnamelen // Fine here. + + go func() { + defer close(ch) + + for _, value := range values { + select { + case <-ctx.Done(): + return + case ch <- value: + } + } + }() + + return MarshalChannelToWriter(ctx, ch, writer, opts) +} + +// MarshalChannelToWriter marshals a channel of structs to a CSV file. +func MarshalChannelToWriter[T any](ctx context.Context, inputCh <-chan T, writer *csv.Writer, opts Options) error { + opts.headersFromStruct = true + + factory, err := newFactory[T](nil, opts) + if err != nil { + return fmt.Errorf("new factory: %w", err) + } + + err = writer.Write(factory.marshalHeaders()) + if err != nil { + return fmt.Errorf("write headers: %w", err) + } + + var currentLine int + var value T + + for { + var channelHasValue bool + + select { + case <-ctx.Done(): + return ctx.Err() //nolint:wrapcheck // No need here. + case value, channelHasValue = <-inputCh: + } + + if !channelHasValue { + break + } + + record, err := factory.marshal(value, string(writer.Comma)) + if err != nil { + return fmt.Errorf("marshal %d: %w", currentLine, err) + } + + err = writer.Write(record) + if err != nil { + return fmt.Errorf("write line %d: %w", currentLine, err) + } + + currentLine++ + } + + writer.Flush() + + if err = writer.Error(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + return nil +} diff --git a/marshal_test.go b/marshal_test.go new file mode 100644 index 0000000..24a8bed --- /dev/null +++ b/marshal_test.go @@ -0,0 +1,56 @@ +package goflat_test + +import ( + "bytes" + "context" + "encoding/csv" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/lzambarda/goflat" +) + +func TestMarshal(t *testing.T) { + expected, err := testdata.ReadFile("testdata/marshal/success.csv") + if err != nil { + t.Fatalf("read test file: %v", err) + } + + type record struct { + FirstName string `flat:"first_name"` + LastName string `flat:"last_name"` + Ignore uint8 `flat:"-"` + Age int `flat:"age"` + Height float32 `flat:"height"` + } + + input := []record{ + { + FirstName: "John", + LastName: "Doe", + Ignore: 123, + Age: 30, + Height: 1.75, + }, + { + FirstName: "Jane", + LastName: "Doe", + Ignore: 123, + Age: 25, + Height: 1.65, + }, + } + var got bytes.Buffer + + writer := csv.NewWriter(&got) + + err = goflat.MarshalSliceToWriter(context.Background(), input, writer, goflat.Options{}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + if diff := cmp.Diff(string(expected), got.String()); diff != "" { + t.Errorf("(-expected, +got):\n%s", diff) + } +} diff --git a/reader.go b/reader.go new file mode 100644 index 0000000..5f79853 --- /dev/null +++ b/reader.go @@ -0,0 +1,65 @@ +package goflat + +import ( + "bytes" + "encoding/csv" + "fmt" + "io" + "strings" +) + +//nolint:gochecknoglobals // We are fine for now. +var commonSeparators = []string{",", ";", "\t", "|"} + +func readFirstLine(reader io.Reader) (string, error) { + b := make([]byte, 1) //nolint:varnamelen // Fine here. + + var line string + + for { + _, err := reader.Read(b) + if err != nil { + if err == io.EOF { + return line, nil + } + + return "", fmt.Errorf("read row: %w", err) + } + + line += string(b) + + if b[0] == '\n' { + return line, nil + } + } +} + +// DetectReader returns a CSV reader with a separator based on a best guess +// about the first line. +func DetectReader(reader io.Reader) (*csv.Reader, error) { + headers, err := readFirstLine(reader) + if err != nil { + return nil, fmt.Errorf("read first line: %w", err) + } + + var bestSeparator string + var bestCount int + + for _, sep := range commonSeparators { + count := strings.Count(headers, sep) + if count > bestCount { + bestCount = count + bestSeparator = sep + } + } + + // Read headers again + rr := io.MultiReader(bytes.NewBufferString(headers), reader) + + csvReader := csv.NewReader(rr) + if bestSeparator != "," { + csvReader.Comma = rune(bestSeparator[0]) + } + + return csvReader, nil +} diff --git a/reflect.go b/reflect.go new file mode 100644 index 0000000..19adce8 --- /dev/null +++ b/reflect.go @@ -0,0 +1,227 @@ +package goflat + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type structFactory[T any] struct { + structType reflect.Type + columnMap []int + columnValues []any + columnNames []string +} + +// FieldTag is the tag that must be used in the struct fields so that goflat can +// work with them. +const FieldTag = "flat" + +// columnMapIgnore is used to mark a column as ignored. This is needed if there +// are duplicate headers that must be skipped. +const columnMapIgnore = -1 + +//nolint:varnamelen // Fine-ish here. +func newFactory[T any](headers []string, options Options) (*structFactory[T], error) { + var v T + + t := reflect.TypeOf(v) + + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("type %T: %w", v, ErrNotAStruct) + } + + factory := &structFactory[T]{ + structType: t, + columnMap: make([]int, len(headers)), + columnValues: make([]any, t.NumField()), + columnNames: make([]string, t.NumField()), + } + + covered := make([]bool, len(headers)) + + rv := reflect.ValueOf(v) + + for i := range t.NumField() { + fieldT := t.Field(i) + fieldV := rv.Field(i) + + factory.columnValues[i] = fieldV.Interface() + + v, ok := fieldT.Tag.Lookup(FieldTag) + if !ok && options.Strict { + return nil, fmt.Errorf("field %q breaks strict mode: %w", fieldT.Name, ErrTaglessField) + } + + if v == "" || v == "-" { + continue + } + + factory.columnNames[i] = v + + handledAt := -1 + + for j, header := range headers { + if covered[j] { + continue + } + + if header != v { + continue + } + + if handledAt >= 0 { + if options.ErrorIfDuplicateHeaders { + return nil, fmt.Errorf("header %q, index %d and %d: %w", header, j, handledAt, ErrDuplicatedHeader) + } + + // If the duplicate headers error flag is diabled, then mark the + // column as ignored and continue. + factory.columnMap[j] = columnMapIgnore + + continue + } + + handledAt = j + covered[j] = true + factory.columnMap[j] = i + } + + if handledAt == -1 && options.ErrorIfMissingHeaders { + return nil, fmt.Errorf("header %q: %w", v, ErrMissingHeader) + } + } + + return factory, nil +} + +//nolint:forcetypeassert,gocyclo,cyclop,ireturn // Fine for now. +func (s *structFactory[T]) unmarshal(record []string) (T, error) { + var zero T + if len(record) != len(s.columnMap) { + return zero, fmt.Errorf("expected %d fields, got %d: %w", len(s.columnMap), len(record), ErrMismatchedFields) + } + + newStruct := reflect.New(s.structType).Elem() + + var value any + var err error + + //nolint:varnamelen // Fine here. + for i, column := range record { + if s.columnMap[i] == columnMapIgnore { + continue + } + + columnBaseValue := s.columnValues[s.columnMap[i]] + + // special case + if u, ok := columnBaseValue.(Unmarshaller); ok { + value, err = u.Unmarshal(column) + } else { + switch columnBaseValue.(type) { + case bool: + value, err = strconv.ParseBool(column) + case int: + value, err = strconv.Atoi(column) + case int8: + value, err = strconv.ParseInt(column, 10, 8) + value = int8(value.(int64)) //nolint:gosec // Safe. + case int16: + value, err = strconv.ParseInt(column, 10, 16) + value = uint16(value.(int64)) //nolint:gosec // Safe. + case int32: + value, err = strconv.ParseInt(column, 10, 32) + value = int32(value.(int64)) //nolint:gosec // Safe. + case int64: + value, err = strconv.ParseInt(column, 10, 64) + case uint: + value, err = strconv.Atoi(column) + value = uint(value.(int)) //nolint:gosec // Safe. + case uint8: // aka byte + value, err = strconv.ParseUint(column, 10, 8) + value = uint8(value.(uint64)) //nolint:gosec // Safe. + case uint16: + value, err = strconv.ParseUint(column, 10, 16) + value = uint16(value.(uint64)) //nolint:gosec // Safe. + case uint32: + value, err = strconv.ParseUint(column, 10, 32) + value = uint32(value.(uint64)) //nolint:gosec // Safe. + case uint64: + value, err = strconv.ParseUint(column, 10, 64) + case float32: + value, err = strconv.ParseFloat(column, 32) + value = float32(value.(float64)) + case float64: + value, err = strconv.ParseFloat(column, 64) + case string: + value = column + default: + return zero, fmt.Errorf("type %T: %w", columnBaseValue, ErrUnsupportedType) + } + } + + if err != nil { + return zero, fmt.Errorf("parse column %d: %w", i, err) + } + + newStruct.Field(s.columnMap[i]).Set(reflect.ValueOf(value)) + } + + return newStruct.Interface().(T), nil +} + +func (s *structFactory[T]) marshalHeaders() []string { + headers := []string{} + + for _, name := range s.columnNames { + if name == "" { + continue + } + + headers = append(headers, name) + } + + return headers +} + +func (s *structFactory[T]) marshal(t T, separator string) ([]string, error) { + reflectValue := reflect.ValueOf(t) + + record := make([]string, 0, len(s.columnNames)) + + var strValue string + var err error + + //nolint:varnamelen // Fine here. + for i, name := range s.columnNames { + if name == "" { + continue + } + + fieldV := reflectValue.Field(i) + + // special case + if m, ok := fieldV.Interface().(Marshaller); ok { + strValue, err = m.Marshal() + if err != nil { + return nil, fmt.Errorf("marshal column %d: %w", i, err) + } + } else { + strValue = fmt.Sprintf("%v", fieldV.Interface()) + } + + strValue = strings.ReplaceAll(strValue, separator, "\\"+separator) + + record = append(record, strValue) + } + + record = record[0:len(record):len(record)] + + return record, nil +} diff --git a/reflect_internal_test.go b/reflect_internal_test.go new file mode 100644 index 0000000..8d628f3 --- /dev/null +++ b/reflect_internal_test.go @@ -0,0 +1,169 @@ +package goflat + +import ( + "reflect" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestReflect(t *testing.T) { + t.Run("error", testReflectError) + t.Run("success", testReflectSuccess) +} + +type s1 struct { + Foo string +} + +func testReflectError(t *testing.T) { + t.Run("tagless", testReflectErrorTaglessStrict) + t.Run("missing", testReflectErrorMissing) + t.Run("duplicate", testReflectErrorDuplicate) +} + +func testReflectErrorTaglessStrict(t *testing.T) { + f, err := newFactory[s1]([]string{}, Options{Strict: true}) + if f != nil { + t.Errorf("expected nil, got %v", f) + } + + if err == nil { + t.Errorf("expected error, got nil") + } +} + +func testReflectErrorMissing(t *testing.T) { + type foo struct { + Name string `flat:"name"` + Age int `flat:"age"` + Skipme string `flat:"-"` + } + + headers := []string{"name"} + + got, err := newFactory[foo](headers, Options{ + Strict: true, + ErrorIfMissingHeaders: true, + }) + if got != nil { + t.Errorf("expected nil, got %v", got) + } + + if err == nil { + t.Errorf("expected error, got nil") + } +} + +func testReflectErrorDuplicate(t *testing.T) { + type foo struct { + Name string `flat:"name"` + Age int `flat:"age"` + Skipme string `flat:"-"` + } + + headers := []string{"name", "age", "name"} + + got, err := newFactory[foo](headers, Options{ + Strict: true, + ErrorIfDuplicateHeaders: true, + }) + if got != nil { + t.Errorf("expected nil, got %v", got) + } + + if err == nil { + t.Errorf("expected error, got nil") + } +} + +func testReflectSuccess(t *testing.T) { + t.Run("duplicate", testReflectSuccessDuplicate) + t.Run("simple", testReflectSuccessSimple) +} + +func testReflectSuccessDuplicate(t *testing.T) { + type foo struct { + Name string `flat:"name"` + Age int `flat:"age"` + } + + headers := []string{"name", "age", "name"} + + got, err := newFactory[foo](headers, Options{ + Strict: true, + ErrorIfDuplicateHeaders: false, + ErrorIfMissingHeaders: true, + }) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + expected := &structFactory[foo]{ + structType: reflect.TypeOf(foo{}), + columnMap: []int{0, 1, -1}, + columnValues: []any{"", int(0)}, + columnNames: []string{"name", "age"}, + } + comparers := []cmp.Option{ + cmp.AllowUnexported(structFactory[foo]{}), + cmp.Comparer(func(a, b structFactory[foo]) bool { + if a.structType.String() != b.structType.String() { + return false + } + + if !slices.Equal(a.columnMap, b.columnMap) { + return false + } + + return true + }), + } + + if diff := cmp.Diff(expected, got, comparers...); diff != "" { + t.Errorf("(-want +got):\\n%s", diff) + } +} + +func testReflectSuccessSimple(t *testing.T) { + type foo struct { + Name string `flat:"name"` + Age int `flat:"age"` + Skipme string `flat:"-"` + } + + headers := []string{"name", "age"} + + got, err := newFactory[foo](headers, Options{ + Strict: true, + ErrorIfDuplicateHeaders: true, + ErrorIfMissingHeaders: true, + }) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + expected := &structFactory[foo]{ + structType: reflect.TypeOf(foo{}), + columnMap: []int{0, 1}, + } + comparers := []cmp.Option{ + cmp.AllowUnexported(structFactory[foo]{}), + cmp.Comparer(func(a, b structFactory[foo]) bool { + if a.structType.String() != b.structType.String() { + return false + } + + if !slices.Equal(a.columnMap, b.columnMap) { + return false + } + + return true + }), + } + + if diff := cmp.Diff(expected, got, comparers...); diff != "" { + t.Errorf("(-want +got):\\n%s", diff) + } +} diff --git a/testdata/marshal/success.csv b/testdata/marshal/success.csv new file mode 100644 index 0000000..b29a262 --- /dev/null +++ b/testdata/marshal/success.csv @@ -0,0 +1,3 @@ +first_name,last_name,age,height +John,Doe,30,1.75 +Jane,Doe,25,1.65 diff --git a/testdata/unmarshal/success.csv b/testdata/unmarshal/success.csv new file mode 100644 index 0000000..f03cefa --- /dev/null +++ b/testdata/unmarshal/success.csv @@ -0,0 +1,4 @@ +first_name,last_name,age,height +Guybrush,Threepwood,28,1.78 +Elaine,Marley,20,1.60 +LeChuck,,100,2.01 diff --git a/unmarshal.go b/unmarshal.go new file mode 100644 index 0000000..1a03071 --- /dev/null +++ b/unmarshal.go @@ -0,0 +1,102 @@ +package goflat + +import ( + "context" + "encoding/csv" + "errors" + "fmt" + "io" + + "golang.org/x/sync/errgroup" +) + +// Options is used to configure the marshalling and unmarshalling processes. +type Options struct { + headersFromStruct bool + // Strict causes goflat to error out if any struct field is missing the + // `flat` tag. + Strict bool + // ErrorIfDuplicateHeaders causes goflat to error out if two struct fields + // share the same `flat` tag value. + ErrorIfDuplicateHeaders bool + // ErrorIfMissingHeaders causes goflat to error out at unmarshalling time if + // a header has no struct field with a corresponding `flat` tag. + ErrorIfMissingHeaders bool +} + +// Unmarshaller can be used to tell goflat to use custom logic to convert the +// input string into the type itself. +type Unmarshaller interface { + Unmarshal(value string) (Unmarshaller, error) +} + +// UnmarshalToChannel unmarshals a CSV file to a channel of structs. It +// automatically closes the channel at the end. +func UnmarshalToChannel[T any](ctx context.Context, reader *csv.Reader, opts Options, outputCh chan<- T) error { + defer close(outputCh) + + headers, err := reader.Read() + if err != nil { + return fmt.Errorf("read headers: %w", err) + } + + factory, err := newFactory[T](headers, opts) + if err != nil { + return fmt.Errorf("new factory: %w", err) + } + + var currentLine int + + for { + record, err := reader.Read() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return fmt.Errorf("read row: %w", err) + } + + value, err := factory.unmarshal(record) + if err != nil { + return fmt.Errorf("get struct at line %d: %w", currentLine, err) + } + + currentLine++ + + select { + case <-ctx.Done(): + return ctx.Err() //nolint:wrapcheck // No need here. + case outputCh <- value: + } + } +} + +// UnmarshalToSlice unmarshals a CSV file to a slice of structs. +func UnmarshalToSlice[T any](ctx context.Context, reader *csv.Reader, opts Options) ([]T, error) { + g, ctx := errgroup.WithContext(ctx) //nolint:varnamelen // Fine here. + + ch := make(chan T) //nolint:varnamelen // Fine here. + + var slice []T + + g.Go(func() error { + for v := range ch { + slice = append(slice, v) + } + + return nil + }) + + g.Go(func() error { + defer close(ch) + + return UnmarshalToChannel(ctx, reader, opts, ch) + }) + + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("wait: %w", err) + } + + return slice, nil +} diff --git a/unmarshal_test.go b/unmarshal_test.go new file mode 100644 index 0000000..7c53715 --- /dev/null +++ b/unmarshal_test.go @@ -0,0 +1,83 @@ +package goflat_test + +import ( + "context" + "embed" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + "github.com/lzambarda/goflat" +) + +func TestUnmarshal(t *testing.T) { + t.Run("success", testUnmarshalSuccess) +} + +//go:embed testdata +var testdata embed.FS + +func testUnmarshalSuccess(t *testing.T) { + file, err := testdata.Open("testdata/unmarshal/success.csv") + if err != nil { + t.Fatalf("open test file: %v", err) + } + + type record struct { + FirstName string `flat:"first_name"` + LastName string `flat:"last_name"` + Age int `flat:"age"` + Height float32 `flat:"height"` + } + + expected := []record{} + + channel := make(chan record) + go assertChannel(t, channel, expected) + + ctx := context.Background() + + csvReader, err := goflat.DetectReader(file) + if err != nil { + t.Fatalf("detect reader: %v", err) + } + + options := goflat.Options{ + Strict: true, + ErrorIfDuplicateHeaders: true, + ErrorIfMissingHeaders: true, + } + + err = goflat.UnmarshalToChannel(ctx, csvReader, options, channel) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } +} + +func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + var got []T + +loop: + for { + select { + case <-ctx.Done(): + break loop + case v, ok := <-ch: + if !ok { + break loop + } + + got = append(got, v) + } + } + + var zero T + if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" { + t.Errorf("(-expected,+got):\n%s", diff) + } +}