diff --git a/iamf/cli/audio_frame_decoder.cc b/iamf/cli/audio_frame_decoder.cc index 9e3c6f4f..b224d90c 100644 --- a/iamf/cli/audio_frame_decoder.cc +++ b/iamf/cli/audio_frame_decoder.cc @@ -88,8 +88,7 @@ absl::Status InitializeWavWriterForSubstreamId( : file_directory / file_name; wav_writers.emplace( - substream_id, - iamf_tools::WavWriter(wav_path, num_channels, sample_rate, bit_depth)); + substream_id, WavWriter(wav_path, num_channels, sample_rate, bit_depth)); return absl::OkStatus(); } @@ -182,68 +181,58 @@ void AbortAllWavWriters( } // namespace -absl::Status AudioFrameDecoder::Decode( - const std::list& encoded_audio_frames, - std::list& decoded_audio_frames) { - // A map of substream IDs to the relevant decoder and codec config. This is - // necessary to process streams with stateful decoders correctly. - absl::node_hash_map> - substream_id_to_decoder; - // A map of substream IDs to the relevant wav writer. - absl::node_hash_map substream_id_to_wav_writer; - - // Initialize all decoders and find all corresponding Codec Config OBUs. - for (const auto& audio_frame : encoded_audio_frames) { - const uint32_t substream_id = audio_frame.obu.GetSubstreamId(); - auto& decoder = substream_id_to_decoder[substream_id]; - if (decoder) { - // Already found the information for this stream. - continue; +// Initializes all decoders and wav writers based on the corresponding Audio +// Element and Codec Config OBUs. +absl::Status AudioFrameDecoder::InitDecodersForSubstreams( + const SubstreamIdLabelsMap& substream_id_to_labels, + const CodecConfigObu& codec_config) { + for (const auto& [substream_id, labels] : substream_id_to_labels) { + auto& decoder = substream_id_to_decoder_[substream_id]; + if (decoder != nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Already initialized decoder for substream ID: ", substream_id, + ". Maybe multiple Audio Element OBUs have the same substream ID?")); } - if (audio_frame.audio_element_with_data == nullptr || - audio_frame.audio_element_with_data->codec_config == nullptr) { - LOG(ERROR) << "Unexpected nullptr in an audio frame with id=" - << substream_id; - return absl::UnknownError(""); - } - - const auto& audio_element = *audio_frame.audio_element_with_data; - const auto& iter = audio_element.substream_id_to_labels.find(substream_id); - if (iter == audio_element.substream_id_to_labels.end()) { - LOG(ERROR) << "Unknown number of channels for substream id: " - << substream_id; - return absl::UnknownError(""); - } - const int num_channels = static_cast(iter->second.size()); + const int num_channels = static_cast(labels.size()); - // Initialize the decoder based on the found Codec Config OBU and number of - // channels. - RETURN_IF_NOT_OK( - InitializeDecoder(*audio_element.codec_config, num_channels, decoder)); + // Initialize the decoder based on the found Codec Config OBU and number + // of channels. + RETURN_IF_NOT_OK(InitializeDecoder(codec_config, num_channels, decoder)); RETURN_IF_NOT_OK(InitializeWavWriterForSubstreamId( substream_id, output_wav_directory_, file_prefix_, num_channels, - static_cast(audio_element.codec_config->GetOutputSampleRate()), - static_cast( - audio_element.codec_config->GetBitDepthToMeasureLoudness()), - substream_id_to_wav_writer)); + static_cast(codec_config.GetOutputSampleRate()), + static_cast(codec_config.GetBitDepthToMeasureLoudness()), + substream_id_to_wav_writer_)); } + return absl::OkStatus(); +} + +absl::Status AudioFrameDecoder::Decode( + const std::list& encoded_audio_frames, + std::list& decoded_audio_frames) { // Decode all frames in all substreams. for (const auto& audio_frame : encoded_audio_frames) { + auto decoder_iter = + substream_id_to_decoder_.find(audio_frame.obu.GetSubstreamId()); + if (decoder_iter == substream_id_to_decoder_.end()) { + return absl::InvalidArgumentError( + absl::StrCat("No decoder found for substream ID: ", + audio_frame.obu.GetSubstreamId())); + } + DecodedAudioFrame decoded_audio_frame; auto decode_status = DecodeAudioFrame( - audio_frame, - substream_id_to_decoder.at(audio_frame.obu.GetSubstreamId()).get(), - decoded_audio_frame); + audio_frame, decoder_iter->second.get(), decoded_audio_frame); if (!decode_status.ok()) { LOG(ERROR) << "Failed to decode audio streams. decode_status: " << decode_status; - AbortAllWavWriters(substream_id_to_wav_writer); + AbortAllWavWriters(substream_id_to_wav_writer_); return decode_status; } RETURN_IF_NOT_OK(DumpDecodedAudioFrameToWavWriter( - decoded_audio_frame, substream_id_to_wav_writer)); + decoded_audio_frame, substream_id_to_wav_writer_)); decoded_audio_frames.push_back(decoded_audio_frame); } diff --git a/iamf/cli/audio_frame_decoder.h b/iamf/cli/audio_frame_decoder.h index 56777bc8..cb7022a1 100644 --- a/iamf/cli/audio_frame_decoder.h +++ b/iamf/cli/audio_frame_decoder.h @@ -14,13 +14,18 @@ #include #include +#include #include #include +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "iamf/cli/audio_element_with_data.h" #include "iamf/cli/audio_frame_with_data.h" +#include "iamf/cli/codec/decoder_base.h" +#include "iamf/cli/wav_writer.h" +#include "iamf/obu/codec_config.h" namespace iamf_tools { @@ -49,6 +54,14 @@ struct DecodedAudioFrame { * This class manages the underlying codec decoders for all substreams. Codec * decoders may be stateful; this class manages a one-to-one mapping between * codec decoders and substream. + * + * Call `InitDecodersForSubstreams` with pairs of `SubstreamIdLabelsMap` and + * `CodecConfigObu`. This typically will require one call per Audio Element OBU. + * + * Then call `Decode` repeatedly with a list of `AudioFrameWithData`. There may + * be multiple `AudioFrameWithData`s in a single call to this function. Each + * substream in the list is assumed to be self-consistent in temporal order. It + * is permitted in any order relative to other substreams. */ class AudioFrameDecoder { public: @@ -62,7 +75,18 @@ class AudioFrameDecoder { : output_wav_directory_(output_wav_directory), file_prefix_(file_prefix) {} - // TODO(b/306319126): Decode one audio frame at a time. + /*\!brief Initialize codec decoders for each substream. + * + * \param substream_id_to_labels Substreams and their associated labels to + * initialize. The number of channels is determined by the number of + * labels. + * \param codec_config Codec Config OBU to use for all substreams. + * \return `absl::OkStatus()` on success. A specific status on failure. + */ + absl::Status InitDecodersForSubstreams( + const SubstreamIdLabelsMap& substream_id_to_labels, + const CodecConfigObu& codec_config); + /*\!brief Decodes a list of Audio Frame OBUs. * * \param encoded_audio_frames Input Audio Frame OBUs. @@ -75,6 +99,14 @@ class AudioFrameDecoder { private: const std::string output_wav_directory_; const std::string file_prefix_; + + // A map of substream IDs to the relevant decoder and codec config. This is + // necessary to process streams with stateful decoders correctly. + absl::node_hash_map> + substream_id_to_decoder_; + + // A map of substream IDs to the relevant wav writer. + absl::node_hash_map substream_id_to_wav_writer_; }; } // namespace iamf_tools diff --git a/iamf/cli/encoder_main_lib.cc b/iamf/cli/encoder_main_lib.cc index 1951be65..a178e03c 100644 --- a/iamf/cli/encoder_main_lib.cc +++ b/iamf/cli/encoder_main_lib.cc @@ -156,6 +156,23 @@ absl::Status CreateOutputDirectory(const std::string& output_directory) { return absl::OkStatus(); } +absl::Status InitAudioFrameDecoderForAllAudioElements( + const absl::flat_hash_map& + audio_elements, + AudioFrameDecoder& audio_frame_decoder) { + for (const auto& [_, audio_element] : audio_elements) { + if (audio_element.codec_config == nullptr) { + // Skip stray audio elements. We won't know how to decode their + // substreams. + continue; + } + + RETURN_IF_NOT_OK(audio_frame_decoder.InitDecodersForSubstreams( + audio_element.substream_id_to_labels, *audio_element.codec_config)); + } + return absl::OkStatus(); +} + absl::Status GenerateObus( const iamf_tools_cli_proto::UserMetadata& user_metadata, const std::string& input_wav_directory, @@ -239,9 +256,17 @@ absl::Status GenerateObus( output_wav_directory, user_metadata.test_vector_metadata().file_name_prefix(), demixing_module, parameters_manager, global_timing_module); - RETURN_IF_NOT_OK(audio_frame_generator.Initialize()); + // Initialize the audio frame decoder. It is needed to determine the recon + // gain parameters and measure the loudness of the mixes. + std::list decoded_audio_frames; + AudioFrameDecoder audio_frame_decoder( + output_wav_directory, + user_metadata.test_vector_metadata().file_name_prefix()); + RETURN_IF_NOT_OK(InitAudioFrameDecoderForAllAudioElements( + audio_elements, audio_frame_decoder)); + // TODO(b/315924757): Currently getting all parameter blocks corresponding to // a timestamp from `parameter_blocks` to simulate // iterative generations. @@ -280,18 +305,15 @@ absl::Status GenerateObus( if (temp_audio_frames.empty()) { absl::SleepFor(absl::Milliseconds(50)); } else { + // Decode all of the newly encoded frames and collect them. + RETURN_IF_NOT_OK( + audio_frame_decoder.Decode(temp_audio_frames, decoded_audio_frames)); + audio_frames.splice(audio_frames.end(), temp_audio_frames); } } PrintAudioFrames(audio_frames); - AudioFrameDecoder audio_frame_decoder( - output_wav_directory, - user_metadata.test_vector_metadata().file_name_prefix()); - std::list decoded_audio_frames; - RETURN_IF_NOT_OK( - audio_frame_decoder.Decode(audio_frames, decoded_audio_frames)); - // Demix audio samples; useful for the following operations. IdTimeLabeledFrameMap id_to_time_to_labeled_frame; IdTimeLabeledFrameMap id_to_time_to_labeled_decoded_frame; diff --git a/iamf/cli/tests/audio_frame_decoder_test.cc b/iamf/cli/tests/audio_frame_decoder_test.cc index 6d9ac1f6..46262801 100644 --- a/iamf/cli/tests/audio_frame_decoder_test.cc +++ b/iamf/cli/tests/audio_frame_decoder_test.cc @@ -30,7 +30,7 @@ const int kNumSamplesPerFrame = 8; const int kBytesPerSample = 2; constexpr absl::string_view kWavFilePrefix = "test"; -TEST(AudioFrameDecoderTest, NoAudioFrames) { +TEST(Decode, SucceedsOnEmptyInput) { AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix); std::list decoded_audio_frames; @@ -64,7 +64,80 @@ std::list PrepareEncodedAudioFrames( return encoded_audio_frames; } -TEST(AudioFrameDecoderTest, DecodeLpcmFrame) { +TEST(Decode, RequiresSubstreamsAreInitialized) { + AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix); + // Encoded frames. + absl::flat_hash_map codec_config_obus; + absl::flat_hash_map audio_elements; + std::list encoded_audio_frames = + PrepareEncodedAudioFrames(codec_config_obus, audio_elements); + + // Decoding fails before substreams are initialized. + std::list decoded_audio_frames; + EXPECT_FALSE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok()); + const auto& audio_element = audio_elements.at(kAudioElementId); + // Decoding succeeds after substreams are initialized. + EXPECT_TRUE( + decoder + .InitDecodersForSubstreams(audio_element.substream_id_to_labels, + *audio_element.codec_config) + .ok()); + EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok()); +} + +TEST(InitDecodersForSubstreams, + ShouldNotBeCalledTwiceWithTheSameSubstreamIdForStatefulEncoders) { + absl::flat_hash_map codec_config_obus; + AddOpusCodecConfigWithId(kCodecConfigId, codec_config_obus); + const auto& codec_config = codec_config_obus.at(kCodecConfigId); + + AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix); + const SubstreamIdLabelsMap kLabelsForSubstreamZero = {{kSubstreamId, {"M"}}}; + EXPECT_TRUE( + decoder.InitDecodersForSubstreams(kLabelsForSubstreamZero, codec_config) + .ok()); + EXPECT_FALSE( + decoder.InitDecodersForSubstreams(kLabelsForSubstreamZero, codec_config) + .ok()); + + const SubstreamIdLabelsMap kLabelsForSubstreamOne = { + {kSubstreamId + 1, {"M"}}}; + EXPECT_TRUE( + decoder.InitDecodersForSubstreams(kLabelsForSubstreamOne, codec_config) + .ok()); +} + +void InitAllAudioElements( + const absl::flat_hash_map& + audio_elements, + AudioFrameDecoder& decoder) { + for (const auto& [audio_element_id, audio_element_with_data] : + audio_elements) { + EXPECT_TRUE(decoder + .InitDecodersForSubstreams( + audio_element_with_data.substream_id_to_labels, + *audio_element_with_data.codec_config) + .ok()); + } +} + +TEST(Decode, AppendsToOutputList) { + AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix); + // Encoded frames. + absl::flat_hash_map codec_config_obus; + absl::flat_hash_map audio_elements; + std::list encoded_audio_frames = + PrepareEncodedAudioFrames(codec_config_obus, audio_elements); + InitAllAudioElements(audio_elements, decoder); + + std::list decoded_audio_frames; + EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok()); + EXPECT_EQ(decoded_audio_frames.size(), 1); + EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok()); + EXPECT_EQ(decoded_audio_frames.size(), 2); +} + +TEST(Decode, DecodesLpcmFrame) { AudioFrameDecoder decoder(::testing::TempDir(), kWavFilePrefix); // Encoded frames. @@ -72,6 +145,7 @@ TEST(AudioFrameDecoderTest, DecodeLpcmFrame) { absl::flat_hash_map audio_elements; std::list encoded_audio_frames = PrepareEncodedAudioFrames(codec_config_obus, audio_elements); + InitAllAudioElements(audio_elements, decoder); // Decode. std::list decoded_audio_frames; @@ -110,6 +184,8 @@ void DecodeEightSampleAudioFrame(uint32_t num_samples_to_trim_at_end = 0, absl::flat_hash_map audio_elements; std::list encoded_audio_frames = PrepareEncodedAudioFrames(codec_config_obus, audio_elements); + InitAllAudioElements(audio_elements, decoder); + encoded_audio_frames.front().obu.header_.num_samples_to_trim_at_end = num_samples_to_trim_at_end; encoded_audio_frames.front().obu.header_.num_samples_to_trim_at_start = @@ -119,7 +195,7 @@ void DecodeEightSampleAudioFrame(uint32_t num_samples_to_trim_at_end = 0, EXPECT_TRUE(decoder.Decode(encoded_audio_frames, decoded_audio_frames).ok()); } -TEST(AudioFrameDecoderTest, WritesDebuggingWavFileWithExpectedNumberOfSamples) { +TEST(Decode, WritesDebuggingWavFileWithExpectedNumberOfSamples) { DecodeEightSampleAudioFrame(); EXPECT_TRUE(std::filesystem::exists(GetFirstExpectedWavFile(kSubstreamId))); @@ -128,7 +204,7 @@ TEST(AudioFrameDecoderTest, WritesDebuggingWavFileWithExpectedNumberOfSamples) { EXPECT_EQ(reader.remaining_samples(), kNumSamplesPerFrame); } -TEST(AudioFrameDecoderTest, DebuggingWavFileHasSamplesTrimmed) { +TEST(Decode, DebuggingWavFileHasSamplesTrimmed) { const uint32_t kNumSamplesToTrimAtEnd = 5; const uint32_t kNumSamplesToTrimAtStart = 2; DecodeEightSampleAudioFrame(kNumSamplesToTrimAtEnd, kNumSamplesToTrimAtStart);