Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 56 additions & 36 deletions spool/spool.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ var (
WithSkipError = withStrict(skiperror)
)

// Spool handler function
type Spoolf = func(context.Context, string, io.Reader) (io.ReadCloser, error)

type Spool struct {
reader FileSystem
writer FileSystem
Expand All @@ -81,13 +84,36 @@ func New(reader, writer FileSystem, opt ...opts.Option[Spool]) *Spool {
return s
}

// Write new file to spool
func (spool *Spool) Write(path string, r io.Reader) error {
return spool.write(spool.reader, path, r)
}
// apply spool function over the file
func (spool *Spool) apply(ctx context.Context, path string, f Spoolf) error {
fd, err := spool.reader.Open(path)
if err != nil {
return spool.iserr(err)
}
defer fd.Close()

func (spool *Spool) WriteFile(path string, b []byte) error {
return spool.Write(path, bytes.NewBuffer(b))
dd, err := f(ctx, path, fd)
if err != nil {
return spool.iserr(err)
}
if dd == nil {
return nil
}
defer dd.Close()

err = spool.write(spool.writer, path, dd)
if err != nil {
return spool.iserr(err)
}

if spool.mutable == mutable {
err = spool.reader.Remove(path)
if err != nil {
return spool.iserr(err)
}
}

return nil
}

func (spool *Spool) iserr(err error) error {
Expand All @@ -99,13 +125,18 @@ func (spool *Spool) iserr(err error) error {
return err
}

// Write new file to spool
func (spool *Spool) Write(path string, r io.Reader) error {
return spool.write(spool.reader, path, r)
}

func (spool *Spool) WriteFile(path string, b []byte) error {
return spool.Write(path, bytes.NewBuffer(b))
}

// Apply the spool function over each file in the reader filesystem, producing
// results to writer file system.
func (spool *Spool) ForEach(
ctx context.Context,
dir string,
f func(context.Context, string, io.Reader) (io.ReadCloser, error),
) error {
func (spool *Spool) ForEach(ctx context.Context, dir string, f Spoolf) error {
return fs.WalkDir(spool.reader, dir,
func(path string, d fs.DirEntry, err error) error {
if err != nil {
Expand All @@ -116,38 +147,27 @@ func (spool *Spool) ForEach(
return nil
}

fd, err := spool.reader.Open(path)
if err != nil {
return spool.iserr(err)
}
defer fd.Close()

dd, err := f(ctx, path, fd)
if err != nil {
return spool.iserr(err)
}
if dd == nil {
return nil
}
defer dd.Close()

err = spool.write(spool.writer, path, dd)
if err != nil {
return spool.iserr(err)
}

if spool.mutable == mutable {
err = spool.reader.Remove(path)
if err != nil {
return spool.iserr(err)
}
if err := spool.apply(ctx, path, f); err != nil {
return err
}

return nil
},
)
}

// Apply the spool function over all file in the reader filesystem, producing
// results to writer file system.
func (spool *Spool) ForEachPath(ctx context.Context, paths []string, f Spoolf) error {
for _, path := range paths {
if err := spool.apply(ctx, path, f); err != nil {
return err
}
}

return nil
}

// Apply the spool function over each file in the reader filesystem, producing
// results to writer file system. It is a variant of [ForEach] that used bytes slices.
func (spool *Spool) ForEachFile(
Expand Down
29 changes: 29 additions & 0 deletions spool/spool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package spool_test

import (
"context"
"io"
"os"
"testing"

Expand Down Expand Up @@ -46,6 +47,34 @@ func TestSpoolForEach(t *testing.T) {
)
}

func TestSpoolForEachPath(t *testing.T) {
in, err := lfs.NewTempFS(os.TempDir(), "in")
it.Then(t).Must(it.Nil(err))

to, err := lfs.NewTempFS(os.TempDir(), "to")
it.Then(t).Must(it.Nil(err))

qq := spool.New(in, to)

seq := []string{"/a", "/b", "/c", "/d", "/e", "/f"}
for _, txt := range seq {
err := qq.WriteFile(txt, []byte(txt))
it.Then(t).Must(it.Nil(err))
}

dat := []string{}
qq.ForEachPath(context.Background(), seq,
func(ctx context.Context, path string, r io.Reader) (io.ReadCloser, error) {
dat = append(dat, path)
return io.NopCloser(r), nil
},
)

it.Then(t).Should(
it.Seq(seq).Equal(dat...),
)
}

func TestSpoolPartition(t *testing.T) {
in, err := lfs.NewTempFS(os.TempDir(), "in")
it.Then(t).Must(it.Nil(err))
Expand Down
Loading