diff --git a/iamf/cli/demixing_module.cc b/iamf/cli/demixing_module.cc index 301d56bb..31f39dbd 100644 --- a/iamf/cli/demixing_module.cc +++ b/iamf/cli/demixing_module.cc @@ -833,14 +833,10 @@ absl::Status ApplyDemixers(const std::list& demixers, } absl::Status GetDemixerMetadata( - const absl::Status init_error, const DecodedUleb128 audio_element_id, + const DecodedUleb128 audio_element_id, const absl::flat_hash_map& audio_element_id_to_demixing_metadata, const DemxingMetadataForAudioElementId*& demixing_metadata) { - if (init_error != absl::OkStatus()) { - return absl::InvalidArgumentError(""); - } - const auto iter = audio_element_id_to_demixing_metadata.find(audio_element_id); if (iter == audio_element_id_to_demixing_metadata.end()) { @@ -871,23 +867,24 @@ absl::Status DemixingModule::FindSamplesOrDemixedSamples( } } -DemixingModule::DemixingModule( +absl::Status DemixingModule::Initialize( const iamf_tools_cli_proto::UserMetadata& user_metadata, const absl::flat_hash_map& - audio_elements) - : init_status_(absl::OkStatus()) { + audio_elements) { for (const auto& audio_frame_metadata : user_metadata.audio_frame_metadata()) { const auto audio_element_id = audio_frame_metadata.audio_element_id(); - init_status_ = FillRequiredDemixingMetadata( - audio_frame_metadata, audio_elements.at(audio_element_id), - audio_element_id_to_demixing_metadata_[audio_element_id]); - - if (init_status_ != absl::OkStatus()) { - LOG(ERROR) << "Initialization of `DemixingModule` failed; abort"; - break; + auto audio_element = audio_elements.find(audio_element_id); + if (audio_element == audio_elements.end()) { + return absl::InvalidArgumentError( + absl::StrCat("Audio Element ID= ", audio_element_id, " not found")); } + + RETURN_IF_NOT_OK(FillRequiredDemixingMetadata( + audio_frame_metadata, audio_element->second, + audio_element_id_to_demixing_metadata_[audio_element_id])); } + return absl::OkStatus(); } absl::Status DemixingModule::DownMixSamplesToSubstreams( @@ -895,14 +892,8 @@ 
absl::Status DemixingModule::DownMixSamplesToSubstreams( LabelSamplesMap& input_label_to_samples, absl::flat_hash_map& substream_id_to_substream_data) const { - if (init_status_ != absl::OkStatus()) { - LOG(ERROR) << "Cannot call `DownMixSamplesToSubstreams()` when " - << "initialization failed."; - return init_status_; - } - const DemxingMetadataForAudioElementId* demixing_metadata = nullptr; - RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id, + RETURN_IF_NOT_OK(GetDemixerMetadata(audio_element_id, audio_element_id_to_demixing_metadata_, demixing_metadata)); @@ -980,11 +971,6 @@ absl::Status DemixingModule::DemixAudioSamples( const std::list& decoded_audio_frames, IdTimeLabeledFrameMap& id_to_time_to_labeled_frame, IdTimeLabeledFrameMap& id_to_time_to_labeled_decoded_frame) const { - if (init_status_ != absl::OkStatus()) { - LOG(ERROR) << "Cannot call `DemixAudioSamples()` when initialization " - << "failed."; - return init_status_; - } for (const auto& [audio_element_id, demixing_metadata] : audio_element_id_to_demixing_metadata_) { auto& time_to_labeled_frame = id_to_time_to_labeled_frame[audio_element_id]; @@ -1023,7 +1009,7 @@ absl::Status DemixingModule::GetDownMixers( DecodedUleb128 audio_element_id, const std::list*& down_mixers) const { const DemxingMetadataForAudioElementId* demixing_metadata = nullptr; - RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id, + RETURN_IF_NOT_OK(GetDemixerMetadata(audio_element_id, audio_element_id_to_demixing_metadata_, demixing_metadata)); down_mixers = &demixing_metadata->down_mixers; @@ -1034,7 +1020,7 @@ absl::Status DemixingModule::GetDemixers( DecodedUleb128 audio_element_id, const std::list*& demixers) const { const DemxingMetadataForAudioElementId* demixing_metadata = nullptr; - RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id, + RETURN_IF_NOT_OK(GetDemixerMetadata(audio_element_id, audio_element_id_to_demixing_metadata_, demixing_metadata)); demixers = 
&demixing_metadata->demixers; diff --git a/iamf/cli/demixing_module.h b/iamf/cli/demixing_module.h index e756d448..530ac451 100644 --- a/iamf/cli/demixing_module.h +++ b/iamf/cli/demixing_module.h @@ -78,13 +78,17 @@ class DemixingModule { LabelGainMap label_to_output_gain; }; - /*!\brief Constructor. + /*!\brief Constructor. */ + DemixingModule() = default; + + /*!\brief Initializes the module to process the given audio elements. * * \param user_metadata Input user metadata. * \param audio_elements Audio elements. Used only for `audio_element_id`, * `substream_id_to_labels`, and `label_to_output_gain`. + * \return `absl::OkStatus()` on success. A specific status on failure. */ - DemixingModule( + absl::Status Initialize( const iamf_tools_cli_proto::UserMetadata& user_metadata, const absl::flat_hash_map& audio_elements); diff --git a/iamf/cli/encoder_main_lib.cc b/iamf/cli/encoder_main_lib.cc index ba413b3f..ce238f5f 100644 --- a/iamf/cli/encoder_main_lib.cc +++ b/iamf/cli/encoder_main_lib.cc @@ -252,7 +252,8 @@ absl::Status GenerateObus( // Demix audio samples while decoding them; useful for recon gain calculation // and measuring loudness.
- DemixingModule demixing_module(user_metadata, audio_elements); + DemixingModule demixing_module; + RETURN_IF_NOT_OK(demixing_module.Initialize(user_metadata, audio_elements)); AudioFrameGenerator audio_frame_generator( user_metadata.audio_frame_metadata(), diff --git a/iamf/cli/tests/BUILD b/iamf/cli/tests/BUILD index 5bc4b884..14fcbe42 100644 --- a/iamf/cli/tests/BUILD +++ b/iamf/cli/tests/BUILD @@ -169,6 +169,7 @@ cc_test( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", ], ) diff --git a/iamf/cli/tests/audio_frame_generator_test.cc b/iamf/cli/tests/audio_frame_generator_test.cc index 457bbe63..595d4436 100644 --- a/iamf/cli/tests/audio_frame_generator_test.cc +++ b/iamf/cli/tests/audio_frame_generator_test.cc @@ -106,7 +106,8 @@ void GenerateAudioFrameWithEightSamples( param_definitions = {}; const std::string output_wav_directory = "/dev/null"; - DemixingModule demixing_module(user_metadata, audio_elements); + DemixingModule demixing_module; + ASSERT_TRUE(demixing_module.Initialize(user_metadata, audio_elements).ok()); GlobalTimingModule global_timing_module; ASSERT_TRUE( global_timing_module.Initialize(audio_elements, param_definitions).ok()); diff --git a/iamf/cli/tests/demixing_module_test.cc b/iamf/cli/tests/demixing_module_test.cc index 02942c2b..c7083d89 100644 --- a/iamf/cli/tests/demixing_module_test.cc +++ b/iamf/cli/tests/demixing_module_test.cc @@ -32,6 +32,7 @@ #include "iamf/obu/demixing_info_param_data.h" #include "iamf/obu/leb128.h" #include "iamf/obu/obu_header.h" +#include "src/google/protobuf/text_format.h" namespace iamf_tools { namespace { @@ -85,6 +86,75 @@ TEST(FindSamplesOrDemixedSamples, ErrorNoMatchingSamples) { absl::StatusCode::kUnknown); } +TEST(Initialize, ValidWhenCalledOncePerAudioElement) { + const DecodedUleb128 kAudioElementId = 137; + iamf_tools_cli_proto::UserMetadata user_metadata; + 
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + audio_element_id: 137 + channel_ids: [ 0, 1 ] + channel_labels: [ "L2", "R2" ] + )pb", + user_metadata.add_audio_frame_metadata())); + absl::flat_hash_map audio_elements; + audio_elements.emplace( + kAudioElementId, + AudioElementWithData{ + .obu = AudioElementObu(ObuHeader(), kAudioElementId, + AudioElementObu::kAudioElementChannelBased, + /*reserved=*/0, + /*codec_config_id=*/0), + .substream_id_to_labels = {{0, {"M"}}, {1, {"L2"}}}, + }); + + DemixingModule demixing_module; + EXPECT_TRUE(demixing_module.Initialize(user_metadata, audio_elements).ok()); + // Each audio element can only be added once. + EXPECT_FALSE(demixing_module.Initialize(user_metadata, audio_elements).ok()); +} + +TEST(Initialize, InvalidWhenChannelLabelsAndChannelIdsMismatch) { + const DecodedUleb128 kAudioElementId = 137; + iamf_tools_cli_proto::UserMetadata user_metadata; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + audio_element_id: 137 + channel_ids: [ 0, 1, 2 ] + channel_labels: [ "L2", "R2" ] + )pb", + user_metadata.add_audio_frame_metadata())); + absl::flat_hash_map audio_elements; + audio_elements.emplace( + kAudioElementId, + AudioElementWithData{ + .obu = AudioElementObu(ObuHeader(), kAudioElementId, + AudioElementObu::kAudioElementChannelBased, + /*reserved=*/0, + /*codec_config_id=*/0), + .substream_id_to_labels = {{0, {"M"}}, {1, {"L2"}}}, + }); + + DemixingModule demixing_module; + EXPECT_FALSE(demixing_module.Initialize(user_metadata, audio_elements).ok()); +} + +TEST(Initialize, InvalidWhenMissingAudioElement) { + iamf_tools_cli_proto::UserMetadata user_metadata; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + audio_element_id: 137 + channel_ids: [ 0, 1 ] + channel_labels: [ "L2", "R2" ] + )pb", + user_metadata.add_audio_frame_metadata())); + const absl::flat_hash_map + kNoMatchingAudioElement; + + DemixingModule demixing_module; + EXPECT_FALSE( + 
demixing_module.Initialize(user_metadata, kNoMatchingAudioElement).ok()); +} + class DemixingModuleTestBase { public: DemixingModuleTestBase() { @@ -104,16 +174,15 @@ class DemixingModuleTestBase { .substream_id_to_labels = substream_id_to_labels_, }); - demixing_module_ = - std::make_unique(user_metadata, audio_elements_); + ASSERT_TRUE( + demixing_module_.Initialize(user_metadata, audio_elements_).ok()); const std::list* down_mixers = nullptr; const std::list* demixers = nullptr; ASSERT_TRUE( - demixing_module_->GetDownMixers(audio_element_id_, down_mixers).ok()); - ASSERT_TRUE( - demixing_module_->GetDemixers(audio_element_id_, demixers).ok()); + demixing_module_.GetDownMixers(audio_element_id_, down_mixers).ok()); + ASSERT_TRUE(demixing_module_.GetDemixers(audio_element_id_, demixers).ok()); EXPECT_EQ(down_mixers->size(), expected_number_of_down_mixers); EXPECT_EQ(demixers->size(), expected_number_of_down_mixers); } @@ -131,7 +200,7 @@ class DemixingModuleTestBase { absl::flat_hash_map audio_elements_; SubstreamIdLabelsMap substream_id_to_labels_; - std::unique_ptr demixing_module_; + DemixingModule demixing_module_; }; class DownMixingModuleTest : public DemixingModuleTestBase, @@ -141,12 +210,12 @@ class DownMixingModuleTest : public DemixingModuleTestBase, int expected_number_of_down_mixers) { TestCreateDemixingModule(expected_number_of_down_mixers); - EXPECT_TRUE( - demixing_module_ - ->DownMixSamplesToSubstreams(audio_element_id_, down_mixing_params, - input_label_to_samples_, - substream_id_to_substream_data_) - .ok()); + EXPECT_TRUE(demixing_module_ + .DownMixSamplesToSubstreams(audio_element_id_, + down_mixing_params, + input_label_to_samples_, + substream_id_to_substream_data_) + .ok()); for (const auto& [substream_id, substream_data] : substream_id_to_substream_data_) { @@ -561,9 +630,9 @@ class DemixingModuleTest : public DemixingModuleTestBase, TestCreateDemixingModule(expected_number_of_down_mixers); EXPECT_TRUE(demixing_module_ - 
->DemixAudioSamples(audio_frames_, decoded_audio_frames_, - unused_id_to_time_to_labeled_frame, - id_to_time_to_labeled_decoded_frame) + .DemixAudioSamples(audio_frames_, decoded_audio_frames_, + unused_id_to_time_to_labeled_frame, + id_to_time_to_labeled_decoded_frame) .ok()); // Check that the demixed samples have the correct values. @@ -595,14 +664,13 @@ TEST_F(DemixingModuleTest, DemixingAudioSamplesSucceedsWithEmptyInputs) { // Clear the inputs. audio_elements_.clear(); - demixing_module_ = - std::make_unique(user_metadata, audio_elements_); + ASSERT_TRUE(demixing_module_.Initialize(user_metadata, audio_elements_).ok()); // Call `DemixAudioSamples()`. IdTimeLabeledFrameMap id_to_time_to_labeled_frame, id_to_time_to_labeled_decoded_frame; EXPECT_TRUE(demixing_module_ - ->DemixAudioSamples( + .DemixAudioSamples( /*audio_frames=*/{}, /*decoded_audio_frames=*/{}, id_to_time_to_labeled_frame, id_to_time_to_labeled_decoded_frame)