Skip to content

Commit

Permalink
Update AudioFrameDecoder to be usable in an iterative manner.
Browse files Browse the repository at this point in the history
  - This interface now supports repeated calls to `Decode`. Upon returning from `Decode` all input frames are done being processed.
  - Update `encoder_main_lib` to decode the frames as they are ready - instead of waiting until the end and they have all been buffered.

PiperOrigin-RevId: 628380206
  • Loading branch information
jwcullen committed Apr 30, 2024
1 parent c97801f commit b127705
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 60 deletions.
83 changes: 36 additions & 47 deletions iamf/cli/audio_frame_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ absl::Status InitializeWavWriterForSubstreamId(
: file_directory / file_name;

wav_writers.emplace(
substream_id,
iamf_tools::WavWriter(wav_path, num_channels, sample_rate, bit_depth));
substream_id, WavWriter(wav_path, num_channels, sample_rate, bit_depth));
return absl::OkStatus();
}

Expand Down Expand Up @@ -182,68 +181,58 @@ void AbortAllWavWriters(

} // namespace

absl::Status AudioFrameDecoder::Decode(
const std::list<AudioFrameWithData>& encoded_audio_frames,
std::list<DecodedAudioFrame>& decoded_audio_frames) {
// A map of substream IDs to the relevant decoder and codec config. This is
// necessary to process streams with stateful decoders correctly.
absl::node_hash_map<uint32_t, std::unique_ptr<DecoderBase>>
substream_id_to_decoder;
// A map of substream IDs to the relevant wav writer.
absl::node_hash_map<uint32_t, WavWriter> substream_id_to_wav_writer;

// Initialize all decoders and find all corresponding Codec Config OBUs.
for (const auto& audio_frame : encoded_audio_frames) {
const uint32_t substream_id = audio_frame.obu.GetSubstreamId();
auto& decoder = substream_id_to_decoder[substream_id];
if (decoder) {
// Already found the information for this stream.
continue;
// Initializes all decoders and wav writers based on the corresponding Audio
// Element and Codec Config OBUs.
absl::Status AudioFrameDecoder::InitDecodersForSubstreams(
const SubstreamIdLabelsMap& substream_id_to_labels,
const CodecConfigObu& codec_config) {
for (const auto& [substream_id, labels] : substream_id_to_labels) {
auto& decoder = substream_id_to_decoder_[substream_id];
if (decoder != nullptr) {
return absl::InvalidArgumentError(absl::StrCat(
"Already initialized decoder for substream ID: ", substream_id,
". Maybe multiple Audio Element OBUs have the same substream ID?"));
}
if (audio_frame.audio_element_with_data == nullptr ||
audio_frame.audio_element_with_data->codec_config == nullptr) {
LOG(ERROR) << "Unexpected nullptr in an audio frame with id="
<< substream_id;
return absl::UnknownError("");
}

const auto& audio_element = *audio_frame.audio_element_with_data;

const auto& iter = audio_element.substream_id_to_labels.find(substream_id);
if (iter == audio_element.substream_id_to_labels.end()) {
LOG(ERROR) << "Unknown number of channels for substream id: "
<< substream_id;
return absl::UnknownError("");
}
const int num_channels = static_cast<int>(iter->second.size());
const int num_channels = static_cast<int>(labels.size());

// Initialize the decoder based on the found Codec Config OBU and number of
// channels.
RETURN_IF_NOT_OK(
InitializeDecoder(*audio_element.codec_config, num_channels, decoder));
// Initialize the decoder based on the found Codec Config OBU and number
// of channels.
RETURN_IF_NOT_OK(InitializeDecoder(codec_config, num_channels, decoder));
RETURN_IF_NOT_OK(InitializeWavWriterForSubstreamId(
substream_id, output_wav_directory_, file_prefix_, num_channels,
static_cast<int>(audio_element.codec_config->GetOutputSampleRate()),
static_cast<int>(
audio_element.codec_config->GetBitDepthToMeasureLoudness()),
substream_id_to_wav_writer));
static_cast<int>(codec_config.GetOutputSampleRate()),
static_cast<int>(codec_config.GetBitDepthToMeasureLoudness()),
substream_id_to_wav_writer_));
}

return absl::OkStatus();
}

absl::Status AudioFrameDecoder::Decode(
const std::list<AudioFrameWithData>& encoded_audio_frames,
std::list<DecodedAudioFrame>& decoded_audio_frames) {
// Decode all frames in all substreams.
for (const auto& audio_frame : encoded_audio_frames) {
auto decoder_iter =
substream_id_to_decoder_.find(audio_frame.obu.GetSubstreamId());
if (decoder_iter == substream_id_to_decoder_.end()) {
return absl::InvalidArgumentError(
absl::StrCat("No decoder found for substream ID: ",
audio_frame.obu.GetSubstreamId()));
}

DecodedAudioFrame decoded_audio_frame;
auto decode_status = DecodeAudioFrame(
audio_frame,
substream_id_to_decoder.at(audio_frame.obu.GetSubstreamId()).get(),
decoded_audio_frame);
audio_frame, decoder_iter->second.get(), decoded_audio_frame);
if (!decode_status.ok()) {
LOG(ERROR) << "Failed to decode audio streams. decode_status: "
<< decode_status;
AbortAllWavWriters(substream_id_to_wav_writer);
AbortAllWavWriters(substream_id_to_wav_writer_);
return decode_status;
}
RETURN_IF_NOT_OK(DumpDecodedAudioFrameToWavWriter(
decoded_audio_frame, substream_id_to_wav_writer));
decoded_audio_frame, substream_id_to_wav_writer_));
decoded_audio_frames.push_back(decoded_audio_frame);
}

Expand Down
34 changes: 33 additions & 1 deletion iamf/cli/audio_frame_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@

#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "iamf/cli/audio_element_with_data.h"
#include "iamf/cli/audio_frame_with_data.h"
#include "iamf/cli/codec/decoder_base.h"
#include "iamf/cli/wav_writer.h"
#include "iamf/obu/codec_config.h"

namespace iamf_tools {

Expand Down Expand Up @@ -49,6 +54,14 @@ struct DecodedAudioFrame {
* This class manages the underlying codec decoders for all substreams. Codec
* decoders may be stateful; this class manages a one-to-one mapping between
* codec decoders and substream.
*
* Call `InitDecodersForSubstreams` with pairs of `SubstreamIdLabelsMap` and
* `CodecConfigObu`. This typically will require one call per Audio Element OBU.
*
* Then call `Decode` repeatedly with a list of `AudioFrameWithData`. There may
* be multiple `AudioFrameWithData`s in a single call to this function. Each
* substream in the list is assumed to be self-consistent in temporal order. It
* is permitted in any order relative to other substreams.
*/
class AudioFrameDecoder {
public:
Expand All @@ -62,7 +75,18 @@ class AudioFrameDecoder {
: output_wav_directory_(output_wav_directory),
file_prefix_(file_prefix) {}

// TODO(b/306319126): Decode one audio frame at a time.
/*\!brief Initialize codec decoders for each substream.
*
* \param substream_id_to_labels Substreams and their associated labels to
* initialize. The number of channels is determined by the number of
* labels.
* \param codec_config Codec Config OBU to use for all substreams.
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
absl::Status InitDecodersForSubstreams(
const SubstreamIdLabelsMap& substream_id_to_labels,
const CodecConfigObu& codec_config);

/*\!brief Decodes a list of Audio Frame OBUs.
*
* \param encoded_audio_frames Input Audio Frame OBUs.
Expand All @@ -75,6 +99,14 @@ class AudioFrameDecoder {
private:
const std::string output_wav_directory_;
const std::string file_prefix_;

// A map of substream IDs to the relevant decoder and codec config. This is
// necessary to process streams with stateful decoders correctly.
absl::node_hash_map<uint32_t, std::unique_ptr<DecoderBase>>
substream_id_to_decoder_;

// A map of substream IDs to the relevant wav writer.
absl::node_hash_map<uint32_t, WavWriter> substream_id_to_wav_writer_;
};

} // namespace iamf_tools
Expand Down
38 changes: 30 additions & 8 deletions iamf/cli/encoder_main_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,23 @@ absl::Status CreateOutputDirectory(const std::string& output_directory) {
return absl::OkStatus();
}

absl::Status InitAudioFrameDecoderForAllAudioElements(
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
audio_elements,
AudioFrameDecoder& audio_frame_decoder) {
for (const auto& [_, audio_element] : audio_elements) {
if (audio_element.codec_config == nullptr) {
// Skip stray audio elements. We won't know how to decode their
// substreams.
continue;
}

RETURN_IF_NOT_OK(audio_frame_decoder.InitDecodersForSubstreams(
audio_element.substream_id_to_labels, *audio_element.codec_config));
}
return absl::OkStatus();
}

absl::Status GenerateObus(
const iamf_tools_cli_proto::UserMetadata& user_metadata,
const std::string& input_wav_directory,
Expand Down Expand Up @@ -239,9 +256,17 @@ absl::Status GenerateObus(
output_wav_directory,
user_metadata.test_vector_metadata().file_name_prefix(), demixing_module,
parameters_manager, global_timing_module);

RETURN_IF_NOT_OK(audio_frame_generator.Initialize());

// Initialize the audio frame decoder. It is needed to determine the recon
// gain parameters and measure the loudness of the mixes.
std::list<DecodedAudioFrame> decoded_audio_frames;
AudioFrameDecoder audio_frame_decoder(
output_wav_directory,
user_metadata.test_vector_metadata().file_name_prefix());
RETURN_IF_NOT_OK(InitAudioFrameDecoderForAllAudioElements(
audio_elements, audio_frame_decoder));

// TODO(b/315924757): Currently getting all parameter blocks corresponding to
// a timestamp from `parameter_blocks` to simulate
// iterative generations.
Expand Down Expand Up @@ -280,18 +305,15 @@ absl::Status GenerateObus(
if (temp_audio_frames.empty()) {
absl::SleepFor(absl::Milliseconds(50));
} else {
// Decode all of the newly encoded frames and collect them.
RETURN_IF_NOT_OK(
audio_frame_decoder.Decode(temp_audio_frames, decoded_audio_frames));

audio_frames.splice(audio_frames.end(), temp_audio_frames);
}
}
PrintAudioFrames(audio_frames);

AudioFrameDecoder audio_frame_decoder(
output_wav_directory,
user_metadata.test_vector_metadata().file_name_prefix());
std::list<DecodedAudioFrame> decoded_audio_frames;
RETURN_IF_NOT_OK(
audio_frame_decoder.Decode(audio_frames, decoded_audio_frames));

// Demix audio samples; useful for the following operations.
IdTimeLabeledFrameMap id_to_time_to_labeled_frame;
IdTimeLabeledFrameMap id_to_time_to_labeled_decoded_frame;
Expand Down
84 changes: 80 additions & 4 deletions iamf/cli/tests/audio_frame_decoder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const int kNumSamplesPerFrame = 8;
const int kBytesPerSample = 2;
constexpr absl::string_view kWavFilePrefix = "test";

TEST(AudioFrameDecoderTest, NoAudioFrames) {
TEST(Decode, SucceedsOnEmptyInput) {
AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix);

std::list<DecodedAudioFrame> decoded_audio_frames;
Expand Down Expand Up @@ -64,14 +64,88 @@ std::list<AudioFrameWithData> PrepareEncodedAudioFrames(
return encoded_audio_frames;
}

TEST(AudioFrameDecoderTest, DecodeLpcmFrame) {
TEST(Decode, RequiresSubstreamsAreInitialized) {
AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix);
// Encoded frames.
absl::flat_hash_map<uint32_t, CodecConfigObu> codec_config_obus;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
std::list<AudioFrameWithData> encoded_audio_frames =
PrepareEncodedAudioFrames(codec_config_obus, audio_elements);

// Decoding fails before substreams are initialized.
std::list<DecodedAudioFrame> decoded_audio_frames;
EXPECT_FALSE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok());
const auto& audio_element = audio_elements.at(kAudioElementId);
// Decoding succeeds after substreams are initialized.
EXPECT_TRUE(
decoder
.InitDecodersForSubstreams(audio_element.substream_id_to_labels,
*audio_element.codec_config)
.ok());
EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok());
}

TEST(InitDecodersForSubstreams,
ShouldNotBeCalledTwiceWithTheSameSubstreamIdForStatefulEncoders) {
absl::flat_hash_map<uint32_t, CodecConfigObu> codec_config_obus;
AddOpusCodecConfigWithId(kCodecConfigId, codec_config_obus);
const auto& codec_config = codec_config_obus.at(kCodecConfigId);

AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix);
const SubstreamIdLabelsMap kLabelsForSubstreamZero = {{kSubstreamId, {"M"}}};
EXPECT_TRUE(
decoder.InitDecodersForSubstreams(kLabelsForSubstreamZero, codec_config)
.ok());
EXPECT_FALSE(
decoder.InitDecodersForSubstreams(kLabelsForSubstreamZero, codec_config)
.ok());

const SubstreamIdLabelsMap kLabelsForSubstreamOne = {
{kSubstreamId + 1, {"M"}}};
EXPECT_TRUE(
decoder.InitDecodersForSubstreams(kLabelsForSubstreamOne, codec_config)
.ok());
}

void InitAllAudioElements(
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
audio_elements,
AudioFrameDecoder& decoder) {
for (const auto& [audio_element_id, audio_element_with_data] :
audio_elements) {
EXPECT_TRUE(decoder
.InitDecodersForSubstreams(
audio_element_with_data.substream_id_to_labels,
*audio_element_with_data.codec_config)
.ok());
}
}

TEST(Decode, AppendsToOutputList) {
AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix);
// Encoded frames.
absl::flat_hash_map<uint32_t, CodecConfigObu> codec_config_obus;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
std::list<AudioFrameWithData> encoded_audio_frames =
PrepareEncodedAudioFrames(codec_config_obus, audio_elements);
InitAllAudioElements(audio_elements, decoder);

std::list<DecodedAudioFrame> decoded_audio_frames;
EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok());
EXPECT_EQ(decoded_audio_frames.size(), 1);
EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok());
EXPECT_EQ(decoded_audio_frames.size(), 2);
}

TEST(Decode, DecodesLpcmFrame) {
AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix);

// Encoded frames.
absl::flat_hash_map<uint32_t, CodecConfigObu> codec_config_obus;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
std::list<AudioFrameWithData> encoded_audio_frames =
PrepareEncodedAudioFrames(codec_config_obus, audio_elements);
InitAllAudioElements(audio_elements, decoder);

// Decode.
std::list<DecodedAudioFrame> decoded_audio_frames;
Expand Down Expand Up @@ -110,6 +184,8 @@ void DecodeEightSampleAudioFrame(uint32_t num_samples_to_trim_at_end = 0,
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
std::list<AudioFrameWithData> encoded_audio_frames =
PrepareEncodedAudioFrames(codec_config_obus, audio_elements);
InitAllAudioElements(audio_elements, decoder);

encoded_audio_frames.front().obu.header_.num_samples_to_trim_at_end =
num_samples_to_trim_at_end;
encoded_audio_frames.front().obu.header_.num_samples_to_trim_at_start =
Expand All @@ -119,7 +195,7 @@ void DecodeEightSampleAudioFrame(uint32_t num_samples_to_trim_at_end = 0,
EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok());
}

TEST(AudioFrameDecoderTest, WritesDebuggingWavFileWithExpectedNumberOfSamples) {
TEST(Decode, WritesDebuggingWavFileWithExpectedNumberOfSamples) {
DecodeEightSampleAudioFrame();

EXPECT_TRUE(std::filesystem::exists(GetFirstExpectedWavFile(kSubstreamId)));
Expand All @@ -128,7 +204,7 @@ TEST(AudioFrameDecoderTest, WritesDebuggingWavFileWithExpectedNumberOfSamples) {
EXPECT_EQ(reader.remaining_samples(), kNumSamplesPerFrame);
}

TEST(AudioFrameDecoderTest, DebuggingWavFileHasSamplesTrimmed) {
TEST(Decode, DebuggingWavFileHasSamplesTrimmed) {
const uint32_t kNumSamplesToTrimAtEnd = 5;
const uint32_t kNumSamplesToTrimAtStart = 2;
DecodeEightSampleAudioFrame(kNumSamplesToTrimAtEnd, kNumSamplesToTrimAtStart);
Expand Down

0 comments on commit b127705

Please sign in to comment.