Skip to content

Commit

Permalink
fix(xtts): update streaming for transformers>=4.42.0 (#59)
Browse files Browse the repository at this point in the history
* Fix Stream Generator on MacOS

* Make it work on mps

* Implement custom tensor.isin

* Fix for latest TF

* Comment out hack for now

* Remove unused code

* build: increase minimum transformers version

* style: fix

---------

Co-authored-by: Enno Hermann <Eginhard@users.noreply.github.com>
  • Loading branch information
gravityrail and eginhard authored Jul 25, 2024
1 parent 20583a4 commit 20bbb41
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions TTS/tts/layers/xtts/stream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def generate( # noqa: PLR0911

elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
Expand All @@ -401,7 +401,7 @@ def generate( # noqa: PLR0911
)
elif is_sample_gen_stream_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
Expand Down Expand Up @@ -463,7 +463,7 @@ def generate( # noqa: PLR0911

elif is_beam_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)

if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ dependencies = [
"gruut[de,es,fr]==2.2.3",
# Tortoise
"einops>=0.6.0",
"transformers>=4.41.1",
"transformers>=4.42.0",
# Bark
"encodec>=0.1.1",
# XTTS
Expand Down

0 comments on commit 20bbb41

Please sign in to comment.