diff --git a/mp4/stts.go b/mp4/stts.go index 79898b75..71cf5085 100644 --- a/mp4/stts.go +++ b/mp4/stts.go @@ -166,12 +166,15 @@ func (b *SttsBox) Info(w io.Writer, specificBoxLevels, indent, indentStep string return bd.err } -// GetSampleNrAtTime - get sample number at or as soon as possible after time +// GetSampleNrAtTime returns the 1-based sample number at or as soon as possible after time. +// Match a final single zero duration if present. +// If time is too big to reach, an error is returned. // Time is calculated by summing up durations of previous samples func (b *SttsBox) GetSampleNrAtTime(sampleStartTime uint64) (sampleNr uint32, err error) { accTime := uint64(0) accNr := uint32(0) - for i := 0; i < len(b.SampleCount); i++ { + nrEntries := len(b.SampleCount) + for i := 0; i < nrEntries; i++ { timeDelta := uint64(b.SampleTimeDelta[i]) if sampleStartTime < accTime+uint64(b.SampleCount[i])*timeDelta { relTime := (sampleStartTime - accTime) @@ -184,5 +187,10 @@ func (b *SttsBox) GetSampleNrAtTime(sampleStartTime uint64) (sampleNr uint32, er accNr += b.SampleCount[i] accTime += timeDelta * uint64(b.SampleCount[i]) } + // Check if there is a final single zero duration and time matches. + if b.SampleTimeDelta[nrEntries-1] == 0 && b.SampleCount[nrEntries-1] == 1 && + sampleStartTime == accTime { + return accNr, nil + } return 0, fmt.Errorf("no matching sample found for time=%d", sampleStartTime) } diff --git a/mp4/stts_test.go b/mp4/stts_test.go index c2af1384..fa2f3cd7 100644 --- a/mp4/stts_test.go +++ b/mp4/stts_test.go @@ -17,26 +17,37 @@ func TestGetSampleNrAtTime(t *testing.T) { SampleTimeDelta: []uint32{10, 14}, } + sttsZero := SttsBox{ + SampleCount: []uint32{2, 1}, + SampleTimeDelta: []uint32{10, 0}, // Single zero duration at end + } + testCases := []struct { + stts SttsBox startTime uint64 sampleNr uint32 expectError bool }{ - {0, 1, false}, - {1, 2, false}, - {10, 2, false}, - {20, 3, false}, - {30, 4, false}, - {31, 5, false}, - {43, 5, false}, - {44, 5, false}, - {45, 6, false}, - {57, 6, false}, - {58, 0, true}, + {stts, 0, 1, false}, + {stts, 1, 2, false}, + {stts, 10, 2, false}, + {stts, 20, 3, false}, + {stts, 30, 4, false}, + {stts, 31, 5, false}, + {stts, 43, 5, false}, + {stts, 44, 5, false}, + {stts, 45, 6, false}, + {stts, 57, 6, false}, + {stts, 58, 0, true}, + {sttsZero, 0, 1, false}, + {sttsZero, 10, 2, false}, + {sttsZero, 19, 3, false}, + {sttsZero, 20, 3, false}, + {sttsZero, 21, 0, true}, } for _, tc := range testCases { - gotNr, err := stts.GetSampleNrAtTime(tc.startTime) + gotNr, err := tc.stts.GetSampleNrAtTime(tc.startTime) if tc.expectError { if err == nil { t.Errorf("Did not get error for startTime %d", tc.startTime)