Skip to content

Commit

Permalink
Sample datasets to create new datasets.
Browse files Browse the repository at this point in the history
  • Loading branch information
zond authored and Martin Bruse committed Jun 27, 2024
1 parent 2418507 commit 43c39f8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 1 deletion.
47 changes: 46 additions & 1 deletion go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"flag"
"fmt"
"log"
"math/rand"
"os"
"path/filepath"
"reflect"
"runtime"
"sort"
Expand Down Expand Up @@ -58,9 +60,13 @@ func main() {
workers := flag.Int("workers", runtime.NumCPU(), "Number of concurrent workers for tasks.")
failFast := flag.Bool("fail_fast", false, "Whether to panic immediately on any error.")
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
sample := flag.String("sample", "", "Glob to directories with databases to sample metadata and audio from.")
sampleDestination := flag.String("sample_destination", "", "Path to directory to put the sampled databases into.")
sampleFraction := flag.Float64("sample_fraction", 1.0, "Fraction of references to copy from the source databases.")
sampleSeed := flag.Int64("sample_seed", 0, "Seed when sampling a random fraction of references.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" && *optimizedMSE == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" && *optimizedMSE == "" && *sample == "" {
flag.Usage()
os.Exit(1)
}
Expand All @@ -72,6 +78,45 @@ func main() {
log.Fatalf("Zimtohrli sample rates != %v not supported by this tool, since it loads all data set audio at %vHz.", aio.DefaultSampleRate, aio.DefaultSampleRate)
}

if *sample != "" {
if *sampleDestination == "" {
log.Fatal("`-sample_destination` required for sample operation")
}
bundles, err := data.OpenBundles(*sample)
if err != nil {
log.Fatal(err)
}
rng := rand.New(rand.NewSource(*sampleSeed))
for _, bundle := range bundles {
dest, err := data.OpenStudy(filepath.Join(*sampleDestination, filepath.Base(bundle.Dir)))
if err != nil {
log.Fatal(err)
}
func() {
defer dest.Close()
if *sampleFraction == 1.0 {
bar := progress.New(fmt.Sprintf("Copying %q", filepath.Base(bundle.Dir)))
if err := dest.Copy(bundle.Dir, bundle.References, bar.Update); err != nil {
log.Fatal(err)
}
bar.Finish()
} else {
numRefs := len(bundle.References)
numWanted := int(*sampleFraction * float64(numRefs))
toCopy := []*data.Reference{}
bar := progress.New(fmt.Sprintf("Copying %v of %q", *sampleFraction, filepath.Base(bundle.Dir)))
for _, index := range rng.Perm(numRefs)[:numWanted] {
toCopy = append(toCopy, bundle.References[index])
}
if err := dest.Copy(bundle.Dir, toCopy, bar.Update); err != nil {
log.Fatal(err)
}
bar.Finish()
}
}()
}
}

if *optimize != "" {
bundles, err := data.OpenBundles(*optimize)
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions go/data/study.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ type ReferenceBundle struct {
mosScaler func(float64) float64
}

// EachDistortion executes f for each distortion in the bundle.
func (r *ReferenceBundle) EachDistortion(f func(*Reference, *Distortion) error) error {
for _, ref := range r.References {
for _, dist := range ref.Distortions {
if err := f(ref, dist); err != nil {
return err
}
}
}
return nil
}

// ReferenceBundles is a slice of ReferenceBundle.
type ReferenceBundles []*ReferenceBundle

Expand Down Expand Up @@ -1072,6 +1084,34 @@ func (s *Study) ViewEachReference(f func(*Reference) error) error {
return nil
}

// Copy inserts some reference into a study, and copies the audio files of the references and their distortions, assuming they are relative to the provided directory.
func (s *Study) Copy(dir string, refs []*Reference, progress func(int, int, int)) error {
for index, ref := range refs {
refCopy := &Reference{}
*refCopy = *ref
newRefPath := fmt.Sprintf("%v_%v", filepath.Base(dir), filepath.Base(ref.Path))
refCopy.Path = newRefPath
if err := os.Symlink(filepath.Join(dir, ref.Path), filepath.Join(s.dir, newRefPath)); err != nil {
return err
}
for index, dist := range ref.Distortions {
distCopy := &Distortion{}
*distCopy = *dist
newDistPath := fmt.Sprintf("%v_%v", filepath.Base(dir), filepath.Base(dist.Path))
distCopy.Path = newDistPath
if err := os.Symlink(filepath.Join(dir, dist.Path), filepath.Join(s.dir, newDistPath)); err != nil {
return err
}
refCopy.Distortions[index] = distCopy
}
if err := s.Put([]*Reference{refCopy}); err != nil {
return err
}
progress(len(refs), index, 0)
}
return nil
}

// Put inserts some references into a study.
func (s *Study) Put(refs []*Reference) error {
tx, err := s.db.Begin()
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.

0 comments on commit 43c39f8

Please sign in to comment.