Skip to content

Commit

Permalink
Added MOS mapping optimization
Browse files Browse the repository at this point in the history
- Made it possible to optimize the MOS mapping for a given dataset with already-calculated Zimtohrli scores.
- Made it possible to print the MOS MSE for a given dataset.
  • Loading branch information
Martin Bruse authored and zond committed Jun 18, 2024
1 parent 5d9704a commit b4ad528
Show file tree
Hide file tree
Showing 5 changed files with 643 additions and 54 deletions.
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ require (
github.com/PuerkitoBio/goquery v1.9.1
github.com/dgryski/go-onlinestats v0.0.0-20170612111826-1c7d19468768
github.com/mattn/go-sqlite3 v1.14.22
gonum.org/v1/gonum v0.15.0
)

require (
github.com/aclements/go-moremath v0.0.0-20210112150236-f10218a38794 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/tools v0.15.0 // indirect
)
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
Expand Down Expand Up @@ -43,4 +45,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.15.0 h1:zdAyfUGbYmuVokhzVmghFl2ZJh5QhcfebBgmVPFYA+8=
golang.org/x/tools v0.15.0/go.mod h1:hpksKq4dtpQWS1uQ61JkdqWM3LscIS6Slf+VVkm+wQk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.15.0 h1:2lYxjRbTYyxkJxlhC+LvJIx3SsANPdRybu1tGj9/OrQ=
gonum.org/v1/gonum v0.15.0/go.mod h1:xzZVBJBtS+Mz4q0Yl2LJTk+OxOg4jiXZ7qBoM0uISGo=
42 changes: 34 additions & 8 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func main() {
leaderboard := flag.String("leaderboard", "", "Glob to directories with databases to compute leaderboard for.")
report := flag.String("report", "", "Glob to directories with databases to generate a report for.")
accuracy := flag.String("accuracy", "", "Glob to directories with databases to provide JND accuracy for.")
mos_mse := flag.String("mos_mse", "", "Glob to directories with databases to provide Zimtohrli-MOS to regular-MOS MSE for.")
optimize := flag.String("optimize", "", "Glob to directories with databases to optimize for.")
optimizeLogfile := flag.String("optimize_logfile", "", "File to write optimization events to.")
optimizeStartStep := flag.Float64("optimize_start_step", 1, "Start step for the simulated annealing.")
Expand All @@ -63,7 +64,7 @@ func main() {
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mos_mse == "" {
flag.Usage()
os.Exit(1)
}
Expand Down Expand Up @@ -99,11 +100,20 @@ func main() {
if err != nil {
log.Fatal(err)
}
params, err := bundles.OptimizeMapping()
result, err := bundles.OptimizeMapping()
if err != nil {
log.Fatal(err)
}
fmt.Println(params)
fmt.Printf("%+v\n", result)
}

makeZimtohrli := func() *goohrli.Goohrli {
if !reflect.DeepEqual(zimtohrliParameters, goohrli.DefaultParameters(zimtohrliParameters.SampleRate)) {
log.Printf("Using %+v", zimtohrliParameters)
}
zimtohrliParameters.SampleRate = sampleRate
z := goohrli.New(zimtohrliParameters)
return z
}

if *calculate != "" {
Expand All @@ -115,11 +125,7 @@ func main() {
for _, study := range studies {
measurements := map[data.ScoreType]data.Measurement{}
if *calculateZimtohrli {
if !reflect.DeepEqual(zimtohrliParameters, goohrli.DefaultParameters(zimtohrliParameters.SampleRate)) {
log.Printf("Using %+v", zimtohrliParameters)
}
zimtohrliParameters.SampleRate = sampleRate
z := goohrli.New(zimtohrliParameters)
z := makeZimtohrli()
measurements[data.ScoreType(*zimtohrliScoreType)] = z.NormalizedAudioDistance
}
if *calculateViSQOL {
Expand Down Expand Up @@ -203,6 +209,26 @@ func main() {
}
}

if *mos_mse != "" {
bundles, err := data.OpenBundles(*mos_mse)
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
if bundle.IsJND() {
fmt.Printf("Not computing MOS MSE for JND dataset %q\n\n", bundle.Dir)
} else {
z := makeZimtohrli()
mse, err := bundle.ZimtohrliMOSMSE(z)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Printf("MSE between human MOS and Zimtohrli MOS: %.15f\n", mse)
}
}
}

if *report != "" {
bundles, err := data.OpenBundles(*report)
if err != nil {
Expand Down
120 changes: 112 additions & 8 deletions go/data/study.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/google/zimtohrli/go/goohrli"
"github.com/google/zimtohrli/go/progress"
"github.com/google/zimtohrli/go/worker"
"gonum.org/v1/gonum/optimize"

_ "github.com/mattn/go-sqlite3" // To open sqlite3-databases.
)
Expand Down Expand Up @@ -95,9 +96,10 @@ type Study struct {

// ReferenceBundle is a plain data type containing a bunch of references, typicall the content of a study.
type ReferenceBundle struct {
Dir string
References []*Reference
ScoreTypes map[ScoreType]int
Dir string
References []*Reference
ScoreTypes map[ScoreType]int
ScoreTypeLimits map[ScoreType][2]*float64
}

// ReferenceBundles is a slice of ReferenceBundle.
Expand All @@ -122,8 +124,20 @@ func (r *ReferenceBundle) SortedTypes() ScoreTypes {
// Add adds a reference to a bundle.
func (r *ReferenceBundle) Add(ref *Reference) {
for _, dist := range ref.Distortions {
for scoreType := range dist.Scores {
for scoreType, value := range dist.Scores {
r.ScoreTypes[scoreType]++
if r.ScoreTypeLimits[scoreType][0] == nil || *r.ScoreTypeLimits[scoreType][0] > value {
valueCopy := value
limits := r.ScoreTypeLimits[scoreType]
limits[0] = &valueCopy
r.ScoreTypeLimits[scoreType] = limits
}
if r.ScoreTypeLimits[scoreType][1] == nil || *r.ScoreTypeLimits[scoreType][1] < value {
valueCopy := value
limits := r.ScoreTypeLimits[scoreType]
limits[1] = &valueCopy
r.ScoreTypeLimits[scoreType] = limits
}
}
}
r.References = append(r.References, ref)
Expand All @@ -132,8 +146,9 @@ func (r *ReferenceBundle) Add(ref *Reference) {
// ToBundle returns a reference bundle for this study.
func (s *Study) ToBundle() (*ReferenceBundle, error) {
result := &ReferenceBundle{
Dir: s.dir,
ScoreTypes: map[ScoreType]int{},
Dir: s.dir,
ScoreTypes: map[ScoreType]int{},
ScoreTypeLimits: map[ScoreType][2]*float64{},
}
if err := s.ViewEachReference(func(ref *Reference) error {
result.Add(ref)
Expand Down Expand Up @@ -384,6 +399,44 @@ func (r *ReferenceBundle) JNDAccuracy() (JNDAccuracyScores, error) {
return result, nil
}

// MOSMSE returns the precision when predicting the MOS score.
func (r *ReferenceBundle) ZimtohrliMOSMSE(z *goohrli.Goohrli) (float64, error) {
if r.IsJND() {
return 0, fmt.Errorf("cannot compute MOS precision on JND references")
}
if _, found := r.ScoreTypes[MOS]; !found {
return 0, fmt.Errorf("cannot compute MOS precision on a data set without MOS")
}

var mosScaler func(mos float64) float64
if math.Abs(*r.ScoreTypeLimits[MOS][0]-1) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-5) < 0.2 {
mosScaler = func(mos float64) float64 {
return mos
}
} else if math.Abs(*r.ScoreTypeLimits[MOS][0]) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-100) < 0.2 {
mosScaler = func(mos float64) float64 {
return 1 + 0.04*mos
}
} else {
return 0, fmt.Errorf("minimum MOS %v and maximum MOS %v are confusing", *r.ScoreTypeLimits[MOS][0], *r.ScoreTypeLimits[MOS][1])
}

sumOfSquares := 0.0
count := 0
for _, ref := range r.References {
for _, dist := range ref.Distortions {
mos, found := dist.Scores[MOS]
if !found {
return 0, fmt.Errorf("%+v doesn't have a MOS score", ref)
}
delta := mosScaler(mos) - z.MOSFromZimtohrli(dist.Scores[Zimtohrli])
sumOfSquares += delta * delta
count++
}
}
return sumOfSquares / float64(count), nil
}

// Studies is a slice of studies.
type Studies []*Study

Expand Down Expand Up @@ -516,8 +569,59 @@ func (r ReferenceBundles) Split(rng *rand.Rand, split float64) (ReferenceBundles
return left, right
}

func (r ReferenceBundles) OptimizeMapping() ([]float32, error) {
return nil, nil
type MappingOptimizationResult struct {
ParamsBefore []float64
MSEBefore float64
ParamsAfter []float64
MSEAfter float64
}

func (r ReferenceBundles) OptimizeMapping() (*MappingOptimizationResult, error) {
z := goohrli.New(goohrli.DefaultParameters(48000))
errors := []error{}
p := optimize.Problem{
Func: func(x []float64) float64 {
params := z.Parameters()
for index := range params.MOSMapperParams {
params.MOSMapperParams[index] = math.Abs(x[index])
}
z.Set(params)
sum := 0.0
count := 0
for _, bundle := range r {
if !bundle.IsJND() {
mse, err := bundle.ZimtohrliMOSMSE(z)
if err != nil {
errors = append(errors, err)
}
sum += mse
count += 1
}
}
return sum / float64(count)
},
Status: func() (optimize.Status, error) {
if len(errors) > 0 {
return optimize.Failure, fmt.Errorf("%+v", errors)
}
return optimize.NotTerminated, nil
},
}
startParams := z.Parameters().MOSMapperParams
result := &MappingOptimizationResult{
ParamsBefore: startParams[:],
MSEBefore: p.Func(startParams[:]),
}
optResult, err := optimize.Minimize(p, startParams[:], nil, nil)
if err != nil {
return nil, err
}
if err := optResult.Status.Err(); err != nil {
return nil, err
}
result.ParamsAfter = optResult.X
result.MSEAfter = optResult.F
return result, nil
}

// OptimizationEvent is a step in the optimization process.
Expand Down
Loading

0 comments on commit b4ad528

Please sign in to comment.