Skip to content

Commit

Permalink
Generates Recon Gain parameter blocks iteratively
Browse files Browse the repository at this point in the history
- Makes the recon gain generator stateless, containing only a static method.
- Additional logging is handled by the caller.
- Drive-by changes: Put error messages in the error status.

PiperOrigin-RevId: 633936645
  • Loading branch information
yero authored and jwcullen committed May 16, 2024
1 parent 001397f commit eac84b6
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 247 deletions.
112 changes: 73 additions & 39 deletions iamf/cli/encoder_main_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
*/
#include "iamf/cli/encoder_main_lib.h"

#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <limits>
Expand Down Expand Up @@ -98,10 +99,23 @@ absl::Status AddAllParameterBlockMetadataForCurrentTimestamp(
const google::protobuf::RepeatedPtrField<
iamf_tools_cli_proto::ParameterBlockObuMetadata>&
parameter_block_metadata,
const absl::flat_hash_map<DecodedUleb128, const ParamDefinition*>&
param_definitions,
const ParamDefinition::ParameterDefinitionType type_to_add,
ParameterBlockGenerator& parameter_block_generator,
int32_t& current_timestamp) {
int32_t next_timestamp = current_timestamp;
for (const auto& metadata : parameter_block_metadata) {
auto param_definition_iter = param_definitions.find(
static_cast<DecodedUleb128>(metadata.parameter_id()));
if (param_definition_iter == param_definitions.end()) {
return absl::InvalidArgumentError(absl::StrCat(
"No param definition found for param ID= ", metadata.parameter_id()));
}
if (param_definition_iter->second->GetType() != type_to_add) {
continue;
}

if (metadata.start_timestamp() == current_timestamp) {
uint32_t duration;
RETURN_IF_NOT_OK(
Expand All @@ -125,6 +139,8 @@ absl::Status MaybeGenerateDemixingAndMixGainParameterBlocks(
const google::protobuf::RepeatedPtrField<
iamf_tools_cli_proto::ParameterBlockObuMetadata>&
parameter_block_metadata,
const absl::flat_hash_map<DecodedUleb128, const ParamDefinition*>&
param_definitions,
ParametersManager& parameters_manager,
ParameterBlockGenerator& parameter_block_generator,
GlobalTimingModule& global_timing_module, int32_t& current_timestamp,
Expand All @@ -137,9 +153,18 @@ absl::Status MaybeGenerateDemixingAndMixGainParameterBlocks(
// Only generate parameter blocks when all audio frames corresponding to
// the same temporal unit are ready.
if (global_audio_frame_timestamp == current_timestamp) {
int32_t current_timestamp_for_demixing = current_timestamp;
int32_t current_timestamp_for_mix_gain = current_timestamp;
RETURN_IF_NOT_OK(AddAllParameterBlockMetadataForCurrentTimestamp(
parameter_block_metadata, parameter_block_generator,
current_timestamp));
parameter_block_metadata, param_definitions,
ParamDefinition::kParameterDefinitionDemixing,
parameter_block_generator, current_timestamp_for_demixing));
RETURN_IF_NOT_OK(AddAllParameterBlockMetadataForCurrentTimestamp(
parameter_block_metadata, param_definitions,
ParamDefinition::kParameterDefinitionMixGain, parameter_block_generator,
current_timestamp_for_mix_gain));
current_timestamp = std::max(current_timestamp_for_demixing,
current_timestamp_for_mix_gain);

std::list<ParameterBlockWithData> mix_gain_parameter_blocks_for_frame;
std::list<ParameterBlockWithData> demixing_parameter_blocks_for_frame;
Expand Down Expand Up @@ -313,9 +338,10 @@ absl::Status GenerateObus(
int32_t current_timestamp = 0;
while (audio_frame_generator.TakingSamples()) {
RETURN_IF_NOT_OK(MaybeGenerateDemixingAndMixGainParameterBlocks(
user_metadata.parameter_block_metadata(), parameters_manager,
parameter_block_generator, global_timing_module, current_timestamp,
demixing_parameter_blocks, mix_gain_parameter_blocks));
user_metadata.parameter_block_metadata(), param_definitions,
parameters_manager, parameter_block_generator, global_timing_module,
current_timestamp, demixing_parameter_blocks,
mix_gain_parameter_blocks));

for (const auto& audio_frame_metadata :
user_metadata.audio_frame_metadata()) {
Expand All @@ -336,50 +362,58 @@ absl::Status GenerateObus(
// TODO(b/329375123): This should be on Thread 2.
IdTimeLabeledFrameMap id_to_time_to_labeled_frame;
IdTimeLabeledFrameMap id_to_time_to_labeled_decoded_frame;
std::list<ParameterBlockWithData> recon_gain_parameter_blocks;
while (audio_frame_generator.GeneratingFrames()) {
std::list<AudioFrameWithData> temp_audio_frames;
RETURN_IF_NOT_OK(audio_frame_generator.OutputFrames(temp_audio_frames));
if (temp_audio_frames.empty()) {
absl::SleepFor(absl::Milliseconds(50));
} else {
// Decode the audio frames. They are required to determine the demixed
// frames.
std::list<DecodedAudioFrame> temp_decoded_audio_frames;
RETURN_IF_NOT_OK(audio_frame_decoder.Decode(temp_audio_frames,
temp_decoded_audio_frames));

// Demix the audio frames.
IdLabeledFrameMap id_to_labeled_frame;
IdLabeledFrameMap id_to_labeled_decoded_frame;
RETURN_IF_NOT_OK(demixing_module.DemixAudioSamples(
temp_audio_frames, temp_decoded_audio_frames, id_to_labeled_frame,
id_to_labeled_decoded_frame));

// Collect and organize in time.
const auto start_timestamp = temp_audio_frames.front().start_timestamp;
for (const auto& [id, labeled_frame] : id_to_labeled_frame) {
id_to_time_to_labeled_frame[id][start_timestamp] = labeled_frame;
}
for (const auto& [id, labeled_decoded_frame] :
id_to_labeled_decoded_frame) {
id_to_time_to_labeled_decoded_frame[id][start_timestamp] =
labeled_decoded_frame;
}
audio_frames.splice(audio_frames.end(), temp_audio_frames);
continue;
}

// TODO(b/315924757): Generate recon gain parameter blocks iteratively
// here.
// Decode the audio frames. They are required to determine the demixed
// frames.
std::list<DecodedAudioFrame> temp_decoded_audio_frames;
RETURN_IF_NOT_OK(audio_frame_decoder.Decode(temp_audio_frames,
temp_decoded_audio_frames));

// Demix the audio frames.
IdLabeledFrameMap id_to_labeled_frame;
IdLabeledFrameMap id_to_labeled_decoded_frame;
RETURN_IF_NOT_OK(demixing_module.DemixAudioSamples(
temp_audio_frames, temp_decoded_audio_frames, id_to_labeled_frame,
id_to_labeled_decoded_frame));

// Add recon gain parameter blocks' metadata.
const auto start_timestamp = temp_audio_frames.front().start_timestamp;
int32_t unused_current_timestamp = start_timestamp;
RETURN_IF_NOT_OK(AddAllParameterBlockMetadataForCurrentTimestamp(
user_metadata.parameter_block_metadata(), param_definitions,
ParamDefinition::kParameterDefinitionReconGain,
parameter_block_generator, unused_current_timestamp));

// Recon gain parameter blocks are generated based on the original and
// demixed audio frames.
std::list<ParameterBlockWithData> temp_recon_gain_parameter_blocks;
RETURN_IF_NOT_OK(parameter_block_generator.GenerateReconGain(
id_to_labeled_frame, id_to_labeled_decoded_frame, global_timing_module,
temp_recon_gain_parameter_blocks));
recon_gain_parameter_blocks.splice(recon_gain_parameter_blocks.end(),
temp_recon_gain_parameter_blocks);

// Collect and organize generated audio frames in time.
for (const auto& [id, labeled_frame] : id_to_labeled_frame) {
id_to_time_to_labeled_frame[id][start_timestamp] = labeled_frame;
}
for (const auto& [id, labeled_decoded_frame] :
id_to_labeled_decoded_frame) {
id_to_time_to_labeled_decoded_frame[id][start_timestamp] =
labeled_decoded_frame;
}
audio_frames.splice(audio_frames.end(), temp_audio_frames);
}
PrintAudioFrames(audio_frames);

// Generate the remaining parameter blocks. Recon gain blocks are
// determined based on the original and demixed audio frames.
std::list<ParameterBlockWithData> recon_gain_parameter_blocks;
RETURN_IF_NOT_OK(parameter_block_generator.GenerateReconGain(
id_to_time_to_labeled_frame, id_to_time_to_labeled_decoded_frame,
global_timing_module, recon_gain_parameter_blocks));

ArbitraryObuGenerator arbitrary_obu_generator(
user_metadata.arbitrary_obu_metadata());
RETURN_IF_NOT_OK(arbitrary_obu_generator.Generate(arbitrary_obus));
Expand Down
Loading

0 comments on commit eac84b6

Please sign in to comment.