def join_fsas_fbw_v2(fsas: Iterable[FsaTuple]) -> WeightedFsaV2:
    """
    Joins a set of FSAs represented as tuples into a single :class:`WeightedFsaV2` object,
    for consumption by the FBW V2 op.

    The state indices of every FSA are shifted by the total number of states of all
    FSAs that precede it, so the joined FSA uses one contiguous state numbering.
    For each input FSA the index of its first and of its last state (after shifting)
    is stored alongside the edges.

    :param fsas: FSAs to be concatenated, represented as tuples with the following fields:
        * number of states S
        * number of edges E
        * integer edge array of shape [E, 3] where each row is an edge
          consisting of from-state, to-state and the emission idx
        * float weight array of shape [E,]

        NOTE(review): the implementation reinterprets the edge buffer with
        ``reshape(3, -1)`` (no transpose) and offsets rows 0 and 1, i.e. it treats the
        data as ordered "all from-states, all to-states, all emission indices".
        Confirm against the producer of the tuples that this matches the layout
        documented above.
    :return: Single FSA object corresponding to the joined FSAs passed as parameter.
    :raises ValueError: if ``fsas`` contains no FSA.
    """
    fsas = list(fsas)  # ensure we can iterate multiple times over this iterable
    if not fsas:
        # Without this guard np.concatenate below fails with the much less helpful
        # "need at least one array to concatenate" (also a ValueError).
        raise ValueError("join_fsas_fbw_v2 needs at least one FSA to join, got an empty iterable")

    num_states = [f[0] for f in fsas]
    num_edges = [f[1] for f in fsas]

    # Exclusive prefix sum over the state counts: index of the first state of each FSA
    # within the joined numbering. The unsigned dtype is kept deliberately so that the
    # in-place addition below behaves as before for the edge arrays passed in.
    start_states = np.cumsum(np.array([0] + num_states, dtype=np.uint32))[:-1]
    # Inclusive prefix sum minus one: index of the last state of each FSA.
    end_states = np.cumsum(num_states) - 1
    weights = np.concatenate(tuple(f[3] for f in fsas))

    edges = []
    for state_offset, f in zip(start_states, fsas):
        # copy so the caller's edge array is never modified in place
        f_edges = f[2].reshape(3, -1).copy()
        # only from-state and to-state rows are state indices; emission indices stay untouched
        f_edges[:2, :] += state_offset
        edges.append(f_edges)

    out_fsa = WeightedFsaV2(
        torch.IntTensor(num_states).to(torch.uint32),
        torch.IntTensor(num_edges).to(torch.uint32),
        torch.IntTensor(np.concatenate(edges, axis=1)).contiguous(),
        torch.Tensor(weights),
        torch.IntTensor(np.array([start_states, end_states])),
    )

    return out_fsa
""" - fsas = list(fsas) # ensure we can iterate multiple times over this iterable - num_states = [f[0] for f in fsas] - num_edges = [f[1] for f in fsas] - start_states = np.cumsum(np.array([0] + num_states, dtype=np.uint32))[:-1] - end_states = np.cumsum(num_states) - 1 - weights = np.concatenate(tuple(f[3] for f in fsas)) - - edges = [] - for idx, f in enumerate(fsas): - f_edges = f[2].reshape(3, -1).copy() - f_edges[:2, :] += start_states[idx] - edges.append(f_edges) - - out_fsa = WeightedFsaV2( - torch.IntTensor(num_states).to(torch.uint32), - torch.IntTensor(num_edges).to(torch.uint32), - torch.IntTensor(np.concatenate(edges, axis=1)).contiguous(), - torch.Tensor(weights), - torch.IntTensor(np.array([start_states, end_states])), - ) - - return out_fsa + return join_fsas_fbw_v2(fsas) class RasrFsaBuilderV2(_RasrFsaBuilderFbw2):