From 9504417a49f5d91abcdd69ea96184305db574e0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nahuel=20Unai=20Rosell=C3=B3=20Beneitez?= Date: Tue, 13 Jan 2026 06:03:48 -0500 Subject: [PATCH 1/3] Set TDP scaling and batched FSA build as static --- i6_models/parts/rasr_fsa.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/i6_models/parts/rasr_fsa.py b/i6_models/parts/rasr_fsa.py index ececc34c..98174d13 100644 --- a/i6_models/parts/rasr_fsa.py +++ b/i6_models/parts/rasr_fsa.py @@ -165,7 +165,8 @@ def __setstate__(self, state): self.__dict__.update(state) self.builder = self.get_builder(config_path=self.config_path) - def apply_tdp_scale_to_fsa_tuple(self, fsa: FsaTuple, tdp_scale: float) -> FsaTuple: + @staticmethod + def apply_tdp_scale_to_fsa_tuple(fsa: FsaTuple, tdp_scale: float) -> FsaTuple: """ Scales the weights of an FSA represented as a tuple by the factor (TDP scale) provided. @@ -204,12 +205,14 @@ def build_single(self, single_identifier: Any) -> FsaTuple: """ ... + @staticmethod @abstractmethod - def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> Union[WeightedFsa, WeightedFsaV2]: + def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> Union[WeightedFsa, WeightedFsaV2]: """ Creates the final FSA to be used by the corresponding `fbw` op from `i6_native_ops`. :param fsas: Sequence of FSAs to be batched together. + :param tdp_scale: FSA weights will be scaled (multiplied) by this value. :return: Single FSA which bundles together all FSAs provided as parameter. The final object is compatible with the corresponding `fbw` op from `i6_native_ops`. """ @@ -230,7 +233,7 @@ def build_batch(self, multiple_identifiers: Iterable[Any]) -> Union[WeightedFsa, fsas: Iterable[FsaTuple] = map(self.build_single, multiple_identifiers) - return self.build_batched_fsa(fsas) + return self.build_batched_fsa(fsas, self.tdp_scale) class RasrFsaBuilder(_AbstractRasrFsaBuilder): @@ -257,7 +260,8 @@ def build_single(self, seq_tag: str) -> FsaTuple: raw_fsa = self.builder.build_by_segment_name(seq_tag) return raw_fsa - def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsa: + @staticmethod + def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsa: """ Build and concatenate the FSAs for a batch of sequence tags and reformat as an input to `i6_native_ops.fbw.fbw_loss`. @@ -273,6 +277,7 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsa: * integer edge array of shape [E, 3] where each row is an edge consisting of from-state, to-state and the emission idx * float weight array of shape [E,] + :param tdp_scale: FSA weights will be scaled (multiplied) by this value. :return: a concatenated FSA """ @@ -300,8 +305,8 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsa: start_end_states, ) - if self.tdp_scale != 1.0: - out_fsa *= self.tdp_scale + if tdp_scale != 1.0: + out_fsa *= tdp_scale return out_fsa @@ -321,7 +326,8 @@ class _RasrFsaBuilderFbw2(_AbstractRasrFsaBuilder): Using any subclass requires a working installation of the python package `librasr`. """ - def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsaV2: + @staticmethod + def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsaV2: """ Joins a set of FSAs represented as tuples into a single :classref:`WeightedFsaV2` object. @@ -331,6 +337,7 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsaV2: * integer edge array of shape [E, 3] where each row is an edge consisting of from-state, to-state and the emission idx * float weight array of shape [E,] + :param tdp_scale: FSA weights will be scaled (multiplied) by this value. :return: Single FSA object corresponding to the joined FSAs passed as parameter. """ From 82e05a412040353cf5afbf2ce9d13d6d5f416d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nahuel=20Unai=20Rosell=C3=B3=20Beneitez?= Date: Tue, 13 Jan 2026 08:22:45 -0500 Subject: [PATCH 2/3] Revert "Set TDP scaling and batched FSA build as static" This reverts commit 9504417a49f5d91abcdd69ea96184305db574e0a. --- i6_models/parts/rasr_fsa.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/i6_models/parts/rasr_fsa.py b/i6_models/parts/rasr_fsa.py index 98174d13..ececc34c 100644 --- a/i6_models/parts/rasr_fsa.py +++ b/i6_models/parts/rasr_fsa.py @@ -165,8 +165,7 @@ def __setstate__(self, state): self.__dict__.update(state) self.builder = self.get_builder(config_path=self.config_path) - @staticmethod - def apply_tdp_scale_to_fsa_tuple(fsa: FsaTuple, tdp_scale: float) -> FsaTuple: + def apply_tdp_scale_to_fsa_tuple(self, fsa: FsaTuple, tdp_scale: float) -> FsaTuple: """ Scales the weights of an FSA represented as a tuple by the factor (TDP scale) provided. @@ -205,14 +204,12 @@ def build_single(self, single_identifier: Any) -> FsaTuple: """ ... - @staticmethod @abstractmethod - def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> Union[WeightedFsa, WeightedFsaV2]: + def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> Union[WeightedFsa, WeightedFsaV2]: """ Creates the final FSA to be used by the corresponding `fbw` op from `i6_native_ops`. :param fsas: Sequence of FSAs to be batched together. - :param tdp_scale: FSA weights will be scaled (multiplied) by this value. :return: Single FSA which bundles together all FSAs provided as parameter. The final object is compatible with the corresponding `fbw` op from `i6_native_ops`. """ @@ -233,7 +230,7 @@ def build_batch(self, multiple_identifiers: Iterable[Any]) -> Union[WeightedFsa, fsas: Iterable[FsaTuple] = map(self.build_single, multiple_identifiers) - return self.build_batched_fsa(fsas, self.tdp_scale) + return self.build_batched_fsa(fsas) class RasrFsaBuilder(_AbstractRasrFsaBuilder): @@ -260,8 +257,7 @@ def build_single(self, seq_tag: str) -> FsaTuple: raw_fsa = self.builder.build_by_segment_name(seq_tag) return raw_fsa - @staticmethod - def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsa: + def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsa: """ Build and concatenate the FSAs for a batch of sequence tags and reformat as an input to `i6_native_ops.fbw.fbw_loss`. @@ -277,7 +273,6 @@ def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsa * integer edge array of shape [E, 3] where each row is an edge consisting of from-state, to-state and the emission idx * float weight array of shape [E,] - :param tdp_scale: FSA weights will be scaled (multiplied) by this value. :return: a concatenated FSA """ @@ -305,8 +300,8 @@ def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsa start_end_states, ) - if tdp_scale != 1.0: - out_fsa *= tdp_scale + if self.tdp_scale != 1.0: + out_fsa *= self.tdp_scale return out_fsa @@ -326,8 +321,7 @@ class _RasrFsaBuilderFbw2(_AbstractRasrFsaBuilder): Using any subclass requires a working installation of the python package `librasr`. """ - @staticmethod - def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsaV2: + def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsaV2: """ Joins a set of FSAs represented as tuples into a single :classref:`WeightedFsaV2` object. @@ -337,7 +331,6 @@ def build_batched_fsa(fsas: Iterable[FsaTuple], tdp_scale: float) -> WeightedFsa * integer edge array of shape [E, 3] where each row is an edge consisting of from-state, to-state and the emission idx * float weight array of shape [E,] - :param tdp_scale: FSA weights will be scaled (multiplied) by this value. :return: Single FSA object corresponding to the joined FSAs passed as parameter. """ From 3d1d25d2d10ed372a0d1eb3e523fd4541db87422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nahuel=20Unai=20Rosell=C3=B3=20Beneitez?= Date: Tue, 13 Jan 2026 08:25:42 -0500 Subject: [PATCH 3/3] Move code into standalone function Interface is left as is --- i6_models/parts/rasr_fsa.py | 60 +++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/i6_models/parts/rasr_fsa.py b/i6_models/parts/rasr_fsa.py index ececc34c..00d7ad9f 100644 --- a/i6_models/parts/rasr_fsa.py +++ b/i6_models/parts/rasr_fsa.py @@ -306,6 +306,43 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsa: return out_fsa +def join_fsas_fbw_v2(fsas: Iterable[FsaTuple]) -> WeightedFsaV2: + """ + Joins a set of FSAs represented as tuples into a single :classref:`WeightedFsaV2` object, + for consumption by the FBW V2 op. + + :param fsas: FSAs to be concatenated, represented as tuples with the following fields: + * number of states S + * number of edges E + * integer edge array of shape [E, 3] where each row is an edge + consisting of from-state, to-state and the emission idx + * float weight array of shape [E,] + :return: Single FSA object corresponding to the joined FSAs passed as parameter. + """ + fsas = list(fsas) # ensure we can iterate multiple times over this iterable + num_states = [f[0] for f in fsas] + num_edges = [f[1] for f in fsas] + start_states = np.cumsum(np.array([0] + num_states, dtype=np.uint32))[:-1] + end_states = np.cumsum(num_states) - 1 + weights = np.concatenate(tuple(f[3] for f in fsas)) + + edges = [] + for idx, f in enumerate(fsas): + f_edges = f[2].reshape(3, -1).copy() + f_edges[:2, :] += start_states[idx] + edges.append(f_edges) + + out_fsa = WeightedFsaV2( + torch.IntTensor(num_states).to(torch.uint32), + torch.IntTensor(num_edges).to(torch.uint32), + torch.IntTensor(np.concatenate(edges, axis=1)).contiguous(), + torch.Tensor(weights), + torch.IntTensor(np.array([start_states, end_states])), + ) + + return out_fsa + + class _RasrFsaBuilderFbw2(_AbstractRasrFsaBuilder): """ Abstract base class for building an FSA. @@ -334,28 +371,7 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsaV2: :return: Single FSA object corresponding to the joined FSAs passed as parameter. """ - fsas = list(fsas) # ensure we can iterate multiple times over this iterable - num_states = [f[0] for f in fsas] - num_edges = [f[1] for f in fsas] - start_states = np.cumsum(np.array([0] + num_states, dtype=np.uint32))[:-1] - end_states = np.cumsum(num_states) - 1 - weights = np.concatenate(tuple(f[3] for f in fsas)) - - edges = [] - for idx, f in enumerate(fsas): - f_edges = f[2].reshape(3, -1).copy() - f_edges[:2, :] += start_states[idx] - edges.append(f_edges) - - out_fsa = WeightedFsaV2( - torch.IntTensor(num_states).to(torch.uint32), - torch.IntTensor(num_edges).to(torch.uint32), - torch.IntTensor(np.concatenate(edges, axis=1)).contiguous(), - torch.Tensor(weights), - torch.IntTensor(np.array([start_states, end_states])), - ) - - return out_fsa + return join_fsas_fbw_v2(fsas) class RasrFsaBuilderV2(_RasrFsaBuilderFbw2):