Skip to content

Commit

Permalink
Cache the SubstreamIdLabelsMap and LabelGainMap instead of a refe…
Browse files Browse the repository at this point in the history
…rence to all audio elements.

  - These fields are expected to be relatively small and they are only stored once per audio element.
  - This should enable future changes to initialize the module in different ways.

PiperOrigin-RevId: 631787681
  • Loading branch information
jwcullen committed May 8, 2024
1 parent 57777f8 commit a105e22
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
52 changes: 29 additions & 23 deletions iamf/cli/demixing_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,18 +586,22 @@ absl::Status Tf2ToT2Demixer(const DownMixingParams& down_mixing_params,
return absl::OkStatus();
}

absl::Status FindRequiredDemixingMetadata(
absl::Status FillRequiredDemixingMetadata(
const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
const SubstreamIdLabelsMap& substream_id_to_labels,
const AudioElementWithData& audio_element_with_data,
DemxingMetadataForAudioElementId& demixing_metadata) {
auto& down_mixers = demixing_metadata.down_mixers;
auto& demixers = demixing_metadata.demixers;

if (!down_mixers.empty() || !demixers.empty()) {
LOG(ERROR) << "`FindRequiredDownMixers()` should only be called once "
LOG(ERROR) << "`FillRequiredDemixingMetadata()` should only be called once "
<< "per Audio Element ID";
return absl::UnknownError("");
}
demixing_metadata.substream_id_to_labels =
audio_element_with_data.substream_id_to_labels;
demixing_metadata.label_to_output_gain =
audio_element_with_data.label_to_output_gain;

const auto& input_channel_ids = audio_frame_metadata.channel_ids();
const auto& input_channel_labels = audio_frame_metadata.channel_labels();
Expand Down Expand Up @@ -629,7 +633,8 @@ absl::Status FindRequiredDemixingMetadata(

// Find the lowest output surround number.
int output_lowest_surround_number = INT_MAX;
for (const auto& [substream_id, labels] : substream_id_to_labels) {
for (const auto& [substream_id, labels] :
demixing_metadata.substream_id_to_labels) {
if (std::find(labels.begin(), labels.end(), "L7") != labels.end() &&
output_lowest_surround_number > 7) {
output_lowest_surround_number = 7;
Expand Down Expand Up @@ -692,7 +697,8 @@ absl::Status FindRequiredDemixingMetadata(

// Find the lowest output height number.
int output_lowest_height_number = INT_MAX;
for (const auto& [substream_id, labels] : substream_id_to_labels) {
for (const auto& [substream_id, labels] :
demixing_metadata.substream_id_to_labels) {
if (std::find(labels.begin(), labels.end(), "Ltf4") != labels.end() &&
output_lowest_height_number > 4) {
output_lowest_height_number = 4;
Expand Down Expand Up @@ -869,13 +875,12 @@ DemixingModule::DemixingModule(
const iamf_tools_cli_proto::UserMetadata& user_metadata,
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
audio_elements)
: audio_elements_(audio_elements), init_status_(absl::OkStatus()) {
: init_status_(absl::OkStatus()) {
for (const auto& audio_frame_metadata :
user_metadata.audio_frame_metadata()) {
const auto audio_element_id = audio_frame_metadata.audio_element_id();
init_status_ = FindRequiredDemixingMetadata(
audio_frame_metadata,
audio_elements.at(audio_element_id).substream_id_to_labels,
init_status_ = FillRequiredDemixingMetadata(
audio_frame_metadata, audio_elements.at(audio_element_id),
audio_element_id_to_demixing_metadata_[audio_element_id]);

if (init_status_ != absl::OkStatus()) {
Expand All @@ -895,19 +900,21 @@ absl::Status DemixingModule::DownMixSamplesToSubstreams(
<< "initialization failed.";
return init_status_;
}
const auto& audio_element_with_data = audio_elements_.at(audio_element_id);
const std::list<Demixer>* down_mixers;
RETURN_IF_NOT_OK(GetDownMixers(audio_element_id, down_mixers));

const DemxingMetadataForAudioElementId* demixing_metadata = nullptr;
RETURN_IF_NOT_OK(GetDemixerMetadata(init_status_, audio_element_id,
audio_element_id_to_demixing_metadata_,
demixing_metadata));

// First perform all the down mixing.
for (const auto& down_mixer : *down_mixers) {
for (const auto& down_mixer : demixing_metadata->down_mixers) {
RETURN_IF_NOT_OK(down_mixer(down_mixing_params, &input_label_to_samples));
}

const size_t num_time_ticks = input_label_to_samples.begin()->second.size();

for (const auto& [substream_id, output_channel_labels] :
audio_element_with_data.substream_id_to_labels) {
demixing_metadata->substream_id_to_labels) {
std::vector<std::vector<int32_t>> substream_samples(
num_time_ticks,
// One or two channels.
Expand All @@ -927,10 +934,10 @@ absl::Status DemixingModule::DownMixSamplesToSubstreams(
}

// Compute and store the linear output gains.
auto gain_iter = audio_element_with_data.label_to_output_gain.find(
output_channel_label);
auto gain_iter =
demixing_metadata->label_to_output_gain.find(output_channel_label);
output_gains_linear[channel_index] = 1.0;
if (gain_iter != audio_element_with_data.label_to_output_gain.end()) {
if (gain_iter != demixing_metadata->label_to_output_gain.end()) {
output_gains_linear[channel_index] =
std::pow(10.0, gain_iter->second / 20.0);
}
Expand Down Expand Up @@ -978,19 +985,18 @@ absl::Status DemixingModule::DemixAudioSamples(
<< "failed.";
return init_status_;
}
for (const auto& [audio_element_id, audio_element_with_data] :
audio_elements_) {
for (const auto& [audio_element_id, demixing_metadata] :
audio_element_id_to_demixing_metadata_) {
auto& time_to_labeled_frame = id_to_time_to_labeled_frame[audio_element_id];
auto& time_to_labeled_decoded_frame =
id_to_time_to_labeled_decoded_frame[audio_element_id];

RETURN_IF_NOT_OK(StoreSamplesForAudioElementId(
audio_frames, decoded_audio_frames,
audio_element_with_data.substream_id_to_labels, time_to_labeled_frame,
demixing_metadata.substream_id_to_labels, time_to_labeled_frame,
time_to_labeled_decoded_frame));
const std::list<Demixer>* demixers;
RETURN_IF_NOT_OK(GetDemixers(audio_element_id, demixers));
RETURN_IF_NOT_OK(ApplyDemixers(*demixers, &time_to_labeled_frame,
RETURN_IF_NOT_OK(ApplyDemixers(demixing_metadata.demixers,
&time_to_labeled_frame,
&time_to_labeled_decoded_frame));

LOG(INFO) << "Demixing Audio Element ID= " << audio_element_id;
Expand Down
4 changes: 2 additions & 2 deletions iamf/cli/demixing_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class DemixingModule {
struct DemxingMetadataForAudioElementId {
std::list<Demixer> demixers;
std::list<Demixer> down_mixers;
SubstreamIdLabelsMap substream_id_to_labels;
LabelGainMap label_to_output_gain;
};

/*\!brief Constructor.
Expand Down Expand Up @@ -152,8 +154,6 @@ class DemixingModule {
const std::list<Demixer>*& demixers) const;

private:
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
audio_elements_;
absl::Status init_status_;

absl::flat_hash_map<DecodedUleb128, DemxingMetadataForAudioElementId>
Expand Down

0 comments on commit a105e22

Please sign in to comment.