diff --git a/spool/spool.go b/spool/spool.go index 1acadda..fc72375 100644 --- a/spool/spool.go +++ b/spool/spool.go @@ -63,6 +63,9 @@ var ( WithSkipError = withStrict(skiperror) ) +// Spool handler function +type Spoolf = func(context.Context, string, io.Reader) (io.ReadCloser, error) + type Spool struct { reader FileSystem writer FileSystem @@ -81,13 +84,36 @@ func New(reader, writer FileSystem, opt ...opts.Option[Spool]) *Spool { return s } -// Write new file to spool -func (spool *Spool) Write(path string, r io.Reader) error { - return spool.write(spool.reader, path, r) -} +// apply spool function over the file +func (spool *Spool) apply(ctx context.Context, path string, f Spoolf) error { + fd, err := spool.reader.Open(path) + if err != nil { + return spool.iserr(err) + } + defer fd.Close() -func (spool *Spool) WriteFile(path string, b []byte) error { - return spool.Write(path, bytes.NewBuffer(b)) + dd, err := f(ctx, path, fd) + if err != nil { + return spool.iserr(err) + } + if dd == nil { + return nil + } + defer dd.Close() + + err = spool.write(spool.writer, path, dd) + if err != nil { + return spool.iserr(err) + } + + if spool.mutable == mutable { + err = spool.reader.Remove(path) + if err != nil { + return spool.iserr(err) + } + } + + return nil } func (spool *Spool) iserr(err error) error { @@ -99,13 +125,18 @@ func (spool *Spool) iserr(err error) error { return err } +// Write new file to spool +func (spool *Spool) Write(path string, r io.Reader) error { + return spool.write(spool.reader, path, r) +} + +func (spool *Spool) WriteFile(path string, b []byte) error { + return spool.Write(path, bytes.NewBuffer(b)) +} + // Apply the spool function over each file in the reader filesystem, producing // results to writer file system. -func (spool *Spool) ForEach( - ctx context.Context, - dir string, - f func(context.Context, string, io.Reader) (io.ReadCloser, error), -) error { +func (spool *Spool) ForEach(ctx context.Context, dir string, f Spoolf) error { return fs.WalkDir(spool.reader, dir, func(path string, d fs.DirEntry, err error) error { if err != nil { @@ -116,31 +147,8 @@ func (spool *Spool) ForEach( return nil } - fd, err := spool.reader.Open(path) - if err != nil { - return spool.iserr(err) - } - defer fd.Close() - - dd, err := f(ctx, path, fd) - if err != nil { - return spool.iserr(err) - } - if dd == nil { - return nil - } - defer dd.Close() - - err = spool.write(spool.writer, path, dd) - if err != nil { - return spool.iserr(err) - } - - if spool.mutable == mutable { - err = spool.reader.Remove(path) - if err != nil { - return spool.iserr(err) - } + if err := spool.apply(ctx, path, f); err != nil { + return err } return nil @@ -148,6 +156,18 @@ func (spool *Spool) ForEach( ) } +// Apply the spool function over all file in the reader filesystem, producing +// results to writer file system. +func (spool *Spool) ForEachPath(ctx context.Context, paths []string, f Spoolf) error { + for _, path := range paths { + if err := spool.apply(ctx, path, f); err != nil { + return err + } + } + + return nil +} + // Apply the spool function over each file in the reader filesystem, producing // results to writer file system. It is a variant of [ForEach] that used bytes slices. func (spool *Spool) ForEachFile( diff --git a/spool/spool_test.go b/spool/spool_test.go index 574b93a..8c0c80a 100644 --- a/spool/spool_test.go +++ b/spool/spool_test.go @@ -10,6 +10,7 @@ package spool_test import ( "context" + "io" "os" "testing" @@ -46,6 +47,34 @@ func TestSpoolForEach(t *testing.T) { ) } +func TestSpoolForEachPath(t *testing.T) { + in, err := lfs.NewTempFS(os.TempDir(), "in") + it.Then(t).Must(it.Nil(err)) + + to, err := lfs.NewTempFS(os.TempDir(), "to") + it.Then(t).Must(it.Nil(err)) + + qq := spool.New(in, to) + + seq := []string{"/a", "/b", "/c", "/d", "/e", "/f"} + for _, txt := range seq { + err := qq.WriteFile(txt, []byte(txt)) + it.Then(t).Must(it.Nil(err)) + } + + dat := []string{} + qq.ForEachPath(context.Background(), seq, + func(ctx context.Context, path string, r io.Reader) (io.ReadCloser, error) { + dat = append(dat, path) + return io.NopCloser(r), nil + }, + ) + + it.Then(t).Should( + it.Seq(seq).Equal(dat...), + ) +} + func TestSpoolPartition(t *testing.T) { in, err := lfs.NewTempFS(os.TempDir(), "in") it.Then(t).Must(it.Nil(err))