diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index ceabd70fa..be3638ac6 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -402,6 +402,7 @@ async def _infer(
             else:
                 # Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
                 import librosa
+
                 silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
                 silence_left = 0
                 if len(silence_intervals) == 0:
@@ -532,7 +533,9 @@ async def _infer_code(
             async for i in results_generator:
                 token_ids = []
                 hidden_states = []
-                if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+                if (
+                    stream and len(i.outputs[0].token_ids) % stream_batch_size == 0
+                ) or i.finished:
                     token_ids.append(torch.tensor(i.outputs[0].token_ids))
                     hidden_states.append(
                         i.outputs[0].hidden_states.to(torch.float32).to(self.device)
                     )
@@ -568,9 +571,7 @@ async def _infer_code(
                 hidden_states = []
                 if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
                     token_ids.append(i.ids[0])
-                    hidden_states.append(
-                        i.hiddens[0].to(torch.float32).to(self.device)
-                    )
+                    hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device))
                     yield GPT.GenerationOutputs(
                         ids=token_ids,
                         finished=i.finished,
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 4abeb241b..3b700f338 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -68,7 +68,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
-                dtype="float32"
+                dtype="float32",
             )
             self.logger.info("vLLM model loaded")
             return
@@ -585,7 +585,7 @@ async def generate(
                     attentions,
                     hiddens,
                     infer_text,
-                    False
+                    False,
                 )

             del not_finished
@@ -609,11 +609,5 @@ async def generate(
         del finish, inputs_ids_buf

         yield self._prepare_generation_outputs(
-            inputs_ids,
-            start_idx,
-            end_idx,
-            attentions,
-            hiddens,
-            infer_text,
-            True
+            inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True
         )
diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py
index afa438a5f..693f779ac 100644
--- a/ChatTTS/model/velocity/model_runner.py
+++ b/ChatTTS/model/velocity/model_runner.py
@@ -107,7 +107,9 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+    ) -> tuple[
+        list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+    ]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -360,7 +362,9 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+    ]:
         speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
@@ -368,9 +372,13 @@ def prepare_input_tensors(
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-                    self._prepare_prompt(seq_group_metadata_list)
-                )
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    speaker_embedding,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions, input_metadata) = self._prepare_decode(
                     seq_group_metadata_list
                 )
@@ -462,7 +470,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
                 perform_sampling=False,
             )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -471,9 +485,13 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-            self.prepare_input_tensors(seq_group_metadata_list)
-        )
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
         # print(sampling_metadata.seq_data)
         seq_groups = []
         for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -522,7 +540,9 @@ def execute_model(
                 if speaker_embedding_params is None:
                     speaker_embedding_params = speaker_embedding[i]
                 else:
-                    speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+                    speaker_embedding_params = torch.cat(
+                        (speaker_embedding_params, speaker_embedding[i])
+                    )
         else:
             speaker_embedding_params = self.post_model(input_tokens, text_mask)

@@ -560,7 +580,7 @@ def execute_model(
         #     sampling_metadata=sampling_metadata,
         # )
         results = []
-        for i,val in enumerate(seq_groups):
+        for i, val in enumerate(seq_groups):
             idx_next_i = idx_next[i, 0, :].tolist()
             logprob_i = logprob[i].tolist()
             tmp_hidden_states = hidden_states[i]
@@ -781,7 +801,9 @@ def _make_tensor_with_pad(
     for x_i in x:
         pad_i = pad
         if isinstance(x[0][0], list):
-            pad_i = [0,] * len(x[0][0])
+            pad_i = [
+                0,
+            ] * len(x[0][0])
         elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
@@ -791,6 +813,7 @@ def _make_tensor_with_pad(
         device=device,
     )

+
 def _make_with_pad(
     x: List[torch.Tensor],
     max_len: int,
@@ -805,11 +828,15 @@ def _make_with_pad(
             padded_x.append(x_i)
         else:
             padded_x.append(
-                torch.cat((torch.zeros(1, max_len-x_i.shape[-2], 768).to(device), x_i), dim=1)
+                torch.cat(
+                    (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+                    dim=1,
+                )
             )

     return padded_x

+
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
         return batch_size
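
Note on the one functional hunk above (the silence check in ChatTTS/core.py): librosa.effects.split returns an array of (start, end) sample indices for every non-silent stretch of audio, where "silent" means more than top_db decibels below the signal peak, so an empty result means the whole tail wavs[0][length:] is silence. The sketch below is only an illustration of that call's semantics, not code from this patch; the helper name tail_silence_offset is hypothetical.

    import numpy as np
    import librosa

    def tail_silence_offset(wav: np.ndarray, length: int, top_db: int = 10) -> int:
        """Return the offset into wav[length:] at which trailing silence begins.

        librosa.effects.split yields (start, end) indices of non-silent
        intervals; if none are found the entire tail is silent (offset 0),
        otherwise silence starts right after the last non-silent interval.
        """
        tail = wav[length:]
        intervals = librosa.effects.split(tail, top_db=top_db)
        if len(intervals) == 0:
            return 0  # the whole tail is silent
        return int(intervals[-1][1])  # end of the last non-silent segment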