diff --git a/spool/spool.go b/spool/spool.go index b8c302c..da81b00 100644 --- a/spool/spool.go +++ b/spool/spool.go @@ -13,6 +13,7 @@ import ( "context" "io" "io/fs" + "path/filepath" "github.com/fogfish/opts" "github.com/fogfish/stream" @@ -172,6 +173,82 @@ func (spool *Spool) ForEachFile( ) } +// Apply the parition function over each file in the reader filesystem, producing +// results to writer file system. +func (spool *Spool) Partition( + ctx context.Context, + dir string, + f func(context.Context, string, io.Reader) (string, error), +) error { + return fs.WalkDir(spool.reader, dir, + func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() { + return nil + } + + fd, err := spool.reader.Open(path) + if err != nil { + return spool.iserr(err) + } + defer fd.Close() + + shard, err := f(ctx, path, fd) + if err != nil { + return spool.iserr(err) + } + if len(shard) == 0 { + return nil + } + + cp, err := spool.reader.Open(path) + if err != nil { + return spool.iserr(err) + } + defer cp.Close() + + err = spool.write(spool.writer, filepath.Join("/", shard, path), cp) + if err != nil { + return spool.iserr(err) + } + + if spool.mutable == mutable { + err = spool.reader.Remove(path) + if err != nil { + return spool.iserr(err) + } + } + + return nil + }, + ) +} + +func (spool *Spool) PartitionFile( + ctx context.Context, + dir string, + f func(context.Context, string, []byte) (string, error), +) error { + return spool.Partition(ctx, dir, + func(ctx context.Context, path string, r io.Reader) (string, error) { + in, err := io.ReadAll(r) + if err != nil { + return "", err + } + + shard, err := f(ctx, path, in) + if err != nil { + return "", err + } + + return shard, nil + }, + ) +} + func (spool *Spool) write(fs stream.CreateFS[struct{}], path string, r io.Reader) error { fd, err := fs.Create(path, nil) if err != nil { diff --git a/spool/spool_test.go b/spool/spool_test.go index ad503af..574b93a 100644 --- a/spool/spool_test.go +++ b/spool/spool_test.go @@ -18,7 +18,7 @@ import ( "github.com/fogfish/stream/spool" ) -func TestSpool(t *testing.T) { +func TestSpoolForEach(t *testing.T) { in, err := lfs.NewTempFS(os.TempDir(), "in") it.Then(t).Must(it.Nil(err)) @@ -45,3 +45,31 @@ func TestSpool(t *testing.T) { it.Seq(seq).Equal(dat...), ) } + +func TestSpoolPartition(t *testing.T) { + in, err := lfs.NewTempFS(os.TempDir(), "in") + it.Then(t).Must(it.Nil(err)) + + to, err := lfs.NewTempFS(os.TempDir(), "to") + it.Then(t).Must(it.Nil(err)) + + qq := spool.New(in, to) + + seq := []string{"/a", "/b", "/c", "/d", "/e", "/f"} + for _, txt := range seq { + err := qq.WriteFile(txt, []byte(txt)) + it.Then(t).Must(it.Nil(err)) + } + + dat := []string{} + qq.PartitionFile(context.Background(), "/", + func(ctx context.Context, path string, b []byte) (string, error) { + dat = append(dat, path) + return path, nil + }, + ) + + it.Then(t).Should( + it.Seq(seq).Equal(dat...), + ) +}