Skip to content

Commit

Permalink
DemixingModule: Separate adding audio elements from constructing.
Browse files Browse the repository at this point in the history
  - Remove unneeded `init_status_`.
  - Add test coverage for when `init_status_` would have been bad.

PiperOrigin-RevId: 631806167
  • Loading branch information
jwcullen committed May 8, 2024
1 parent a105e22 commit 0112e6b
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 51 deletions.
44 changes: 15 additions & 29 deletions iamf/cli/demixing_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -833,14 +833,10 @@ absl::Status ApplyDemixers(const std::list<Demixer>& demixers,
}

absl::Status GetDemixerMetadata(
const absl::Status init_error, const DecodedUleb128 audio_element_id,
const DecodedUleb128 audio_element_id,
const absl::flat_hash_map<DecodedUleb128, DemxingMetadataForAudioElementId>&
audio_element_id_to_demixing_metadata,
const DemxingMetadataForAudioElementId*& demixing_metadata) {
if (init_error != absl::OkStatus()) {
return absl::InvalidArgumentError("");
}

const auto iter =
audio_element_id_to_demixing_metadata.find(audio_element_id);
if (iter == audio_element_id_to_demixing_metadata.end()) {
Expand Down Expand Up @@ -871,38 +867,33 @@ absl::Status DemixingModule::FindSamplesOrDemixedSamples(
}
}

DemixingModule::DemixingModule(
absl::Status DemixingModule::Initialize(
const iamf_tools_cli_proto::UserMetadata& user_metadata,
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
audio_elements)
: init_status_(absl::OkStatus()) {
audio_elements) {
for (const auto& audio_frame_metadata :
user_metadata.audio_frame_metadata()) {
const auto audio_element_id = audio_frame_metadata.audio_element_id();
init_status_ = FillRequiredDemixingMetadata(
audio_frame_metadata, audio_elements.at(audio_element_id),
audio_element_id_to_demixing_metadata_[audio_element_id]);

if (init_status_ != absl::OkStatus()) {
LOG(ERROR) << "Initialization of `DemixingModule` failed; abort";
break;
auto audio_element = audio_elements.find(audio_element_id);
if (audio_element == audio_elements.end()) {
return absl::InvalidArgumentError(
absl::StrCat("Audio Element ID= ", audio_element_id, " not found"));
}

RETURN_IF_NOT_OK(FillRequiredDemixingMetadata(
audio_frame_metadata, audio_element->second,
audio_element_id_to_demixing_metadata_[audio_element_id]));
}
return absl::OkStatus();
}

absl::Status DemixingModule::DownMixSamplesToSubstreams(
DecodedUleb128 audio_element_id, const DownMixingParams& down_mixing_params,
LabelSamplesMap& input_label_to_samples,
absl::flat_hash_map<uint32_t, SubstreamData>&
substream_id_to_substream_data) const {
if (init_status_ != absl::OkStatus()) {
LOG(ERROR) << "Cannot call `DownMixSamplesToSubstreams()` when "
<< "initialization failed.";
return init_status_;
}

const DemxingMetadataForAudioElementId* demixing_metadata = nullptr;
RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id,
RETURN_IF_NOT_OK(GetDemixerMetadata(audio_element_id,
audio_element_id_to_demixing_metadata_,
demixing_metadata));

Expand Down Expand Up @@ -980,11 +971,6 @@ absl::Status DemixingModule::DemixAudioSamples(
const std::list<DecodedAudioFrame>& decoded_audio_frames,
IdTimeLabeledFrameMap& id_to_time_to_labeled_frame,
IdTimeLabeledFrameMap& id_to_time_to_labeled_decoded_frame) const {
if (init_status_ != absl::OkStatus()) {
LOG(ERROR) << "Cannot call `DemixAudioSamples()` when initialization "
<< "failed.";
return init_status_;
}
for (const auto& [audio_element_id, demixing_metadata] :
audio_element_id_to_demixing_metadata_) {
auto& time_to_labeled_frame = id_to_time_to_labeled_frame[audio_element_id];
Expand Down Expand Up @@ -1023,7 +1009,7 @@ absl::Status DemixingModule::GetDownMixers(
DecodedUleb128 audio_element_id,
const std::list<Demixer>*& down_mixers) const {
const DemxingMetadataForAudioElementId* demixing_metadata = nullptr;
RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id,
RETURN_IF_NOT_OK(GetDemixerMetadata(audio_element_id,
audio_element_id_to_demixing_metadata_,
demixing_metadata));
down_mixers = &demixing_metadata->down_mixers;
Expand All @@ -1034,7 +1020,7 @@ absl::Status DemixingModule::GetDemixers(
DecodedUleb128 audio_element_id,
const std::list<Demixer>*& demixers) const {
const DemxingMetadataForAudioElementId* demixing_metadata = nullptr;
RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id,
RETURN_IF_NOT_OK(GetDemixerMetadata(audio_element_id,
audio_element_id_to_demixing_metadata_,
demixing_metadata));
demixers = &demixing_metadata->demixers;
Expand Down
8 changes: 6 additions & 2 deletions iamf/cli/demixing_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,17 @@ class DemixingModule {
LabelGainMap label_to_output_gain;
};

/*!\brief Constructor.
/*!\brief Constructor. */
DemixingModule() = default;

/*!\brief Initializes the module to process the given audio elements.
*
* \param user_metadata Input user metadata.
* \param audio_elements Audio elements. Used only for `audio_element_id`,
* `substream_id_to_labels`, and `label_to_output_gain`.
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
DemixingModule(
absl::Status Initialize(
const iamf_tools_cli_proto::UserMetadata& user_metadata,
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
audio_elements);
Expand Down
3 changes: 2 additions & 1 deletion iamf/cli/encoder_main_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ absl::Status GenerateObus(

// Demix audio samples while decoding them; useful for recon gain calculation
// and measuring loudness.
DemixingModule demixing_module(user_metadata, audio_elements);
DemixingModule demixing_module;
RETURN_IF_NOT_OK(demixing_module.Initialize(user_metadata, audio_elements));

AudioFrameGenerator audio_frame_generator(
user_metadata.audio_frame_metadata(),
Expand Down
1 change: 1 addition & 0 deletions iamf/cli/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ cc_test(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@com_google_protobuf//:protobuf",
],
)

Expand Down
3 changes: 2 additions & 1 deletion iamf/cli/tests/audio_frame_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ void GenerateAudioFrameWithEightSamples(
param_definitions = {};
const std::string output_wav_directory = "/dev/null";

DemixingModule demixing_module(user_metadata, audio_elements);
DemixingModule demixing_module;
ASSERT_TRUE(demixing_module.Initialize(user_metadata, audio_elements).ok());
GlobalTimingModule global_timing_module;
ASSERT_TRUE(
global_timing_module.Initialize(audio_elements, param_definitions).ok());
Expand Down
104 changes: 86 additions & 18 deletions iamf/cli/tests/demixing_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "iamf/obu/demixing_info_param_data.h"
#include "iamf/obu/leb128.h"
#include "iamf/obu/obu_header.h"
#include "src/google/protobuf/text_format.h"

namespace iamf_tools {
namespace {
Expand Down Expand Up @@ -85,6 +86,75 @@ TEST(FindSamplesOrDemixedSamples, ErrorNoMatchingSamples) {
absl::StatusCode::kUnknown);
}

// Verifies that `DemixingModule::Initialize()` succeeds on its first call for
// an audio element and fails when called a second time with the same inputs.
TEST(Initialize, ValidWhenCalledOncePerAudioElement) {
  const DecodedUleb128 kTestAudioElementId = 137;

  // One channel-based audio element, keyed by the ID the metadata refers to.
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  audio_elements.emplace(
      kTestAudioElementId,
      AudioElementWithData{
          .obu = AudioElementObu(ObuHeader(), kTestAudioElementId,
                                 AudioElementObu::kAudioElementChannelBased,
                                 /*reserved=*/0,
                                 /*codec_config_id=*/0),
          .substream_id_to_labels = {{0, {"M"}}, {1, {"L2"}}},
      });

  // Audio frame metadata for that element: two channel IDs and two labels.
  iamf_tools_cli_proto::UserMetadata user_metadata;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
audio_element_id: 137
channel_ids: [ 0, 1 ]
channel_labels: [ "L2", "R2" ]
)pb",
      user_metadata.add_audio_frame_metadata()));

  DemixingModule demixing_module;
  // The first initialization is expected to succeed.
  EXPECT_TRUE(demixing_module.Initialize(user_metadata, audio_elements).ok());
  // Each audio element can only be added once, so the repeat call must fail.
  EXPECT_FALSE(demixing_module.Initialize(user_metadata, audio_elements).ok());
}

// Verifies that `DemixingModule::Initialize()` fails when the audio frame
// metadata lists a different number of channel IDs than channel labels.
TEST(Initialize, InvalidWhenChannelLabelsAndChannelIdsMismatch) {
  const DecodedUleb128 kTestAudioElementId = 137;

  // A matching audio element exists, so the only problem is the metadata.
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  audio_elements.emplace(
      kTestAudioElementId,
      AudioElementWithData{
          .obu = AudioElementObu(ObuHeader(), kTestAudioElementId,
                                 AudioElementObu::kAudioElementChannelBased,
                                 /*reserved=*/0,
                                 /*codec_config_id=*/0),
          .substream_id_to_labels = {{0, {"M"}}, {1, {"L2"}}},
      });

  // Three channel IDs but only two channel labels.
  iamf_tools_cli_proto::UserMetadata user_metadata;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
audio_element_id: 137
channel_ids: [ 0, 1, 2 ]
channel_labels: [ "L2", "R2" ]
)pb",
      user_metadata.add_audio_frame_metadata()));

  DemixingModule demixing_module;
  EXPECT_FALSE(demixing_module.Initialize(user_metadata, audio_elements).ok());
}

// Verifies that `DemixingModule::Initialize()` fails when the audio frame
// metadata refers to an audio element ID that is absent from the map.
TEST(Initialize, InvalidWhenMissingAudioElement) {
  // Deliberately empty: no entry exists for the ID used in the metadata below.
  const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>
      kNoMatchingAudioElement;

  // Audio frame metadata that names audio element 137.
  iamf_tools_cli_proto::UserMetadata user_metadata;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
audio_element_id: 137
channel_ids: [ 0, 1 ]
channel_labels: [ "L2", "R2" ]
)pb",
      user_metadata.add_audio_frame_metadata()));

  DemixingModule demixing_module;
  EXPECT_FALSE(
      demixing_module.Initialize(user_metadata, kNoMatchingAudioElement).ok());
}

class DemixingModuleTestBase {
public:
DemixingModuleTestBase() {
Expand All @@ -104,16 +174,15 @@ class DemixingModuleTestBase {
.substream_id_to_labels = substream_id_to_labels_,
});

demixing_module_ =
std::make_unique<DemixingModule>(user_metadata, audio_elements_);
ASSERT_TRUE(
demixing_module_.Initialize(user_metadata, audio_elements_).ok());

const std::list<Demixer>* down_mixers = nullptr;
const std::list<Demixer>* demixers = nullptr;

ASSERT_TRUE(
demixing_module_->GetDownMixers(audio_element_id_, down_mixers).ok());
ASSERT_TRUE(
demixing_module_->GetDemixers(audio_element_id_, demixers).ok());
demixing_module_.GetDownMixers(audio_element_id_, down_mixers).ok());
ASSERT_TRUE(demixing_module_.GetDemixers(audio_element_id_, demixers).ok());
EXPECT_EQ(down_mixers->size(), expected_number_of_down_mixers);
EXPECT_EQ(demixers->size(), expected_number_of_down_mixers);
}
Expand All @@ -131,7 +200,7 @@ class DemixingModuleTestBase {
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements_;
SubstreamIdLabelsMap substream_id_to_labels_;

std::unique_ptr<DemixingModule> demixing_module_;
DemixingModule demixing_module_;
};

class DownMixingModuleTest : public DemixingModuleTestBase,
Expand All @@ -141,12 +210,12 @@ class DownMixingModuleTest : public DemixingModuleTestBase,
int expected_number_of_down_mixers) {
TestCreateDemixingModule(expected_number_of_down_mixers);

EXPECT_TRUE(
demixing_module_
->DownMixSamplesToSubstreams(audio_element_id_, down_mixing_params,
input_label_to_samples_,
substream_id_to_substream_data_)
.ok());
EXPECT_TRUE(demixing_module_
.DownMixSamplesToSubstreams(audio_element_id_,
down_mixing_params,
input_label_to_samples_,
substream_id_to_substream_data_)
.ok());

for (const auto& [substream_id, substream_data] :
substream_id_to_substream_data_) {
Expand Down Expand Up @@ -561,9 +630,9 @@ class DemixingModuleTest : public DemixingModuleTestBase,
TestCreateDemixingModule(expected_number_of_down_mixers);

EXPECT_TRUE(demixing_module_
->DemixAudioSamples(audio_frames_, decoded_audio_frames_,
unused_id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame)
.DemixAudioSamples(audio_frames_, decoded_audio_frames_,
unused_id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame)
.ok());

// Check that the demixed samples have the correct values.
Expand Down Expand Up @@ -595,14 +664,13 @@ TEST_F(DemixingModuleTest, DemixingAudioSamplesSucceedsWithEmptyInputs) {

// Clear the inputs.
audio_elements_.clear();
demixing_module_ =
std::make_unique<DemixingModule>(user_metadata, audio_elements_);
ASSERT_TRUE(demixing_module_.Initialize(user_metadata, audio_elements_).ok());

// Call `DemixAudioSamples()`.
IdTimeLabeledFrameMap id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame;
EXPECT_TRUE(demixing_module_
->DemixAudioSamples(
.DemixAudioSamples(
/*audio_frames=*/{},
/*decoded_audio_frames=*/{}, id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame)
Expand Down

0 comments on commit 0112e6b

Please sign in to comment.