Dev/use vllm #1053

Merged · 7 commits · Mar 13, 2025

Changes from all commits
43 changes: 40 additions & 3 deletions cosyvoice/cli/cosyvoice.py
@@ -19,7 +19,7 @@
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.class_utils import get_model_type

@@ -63,6 +63,9 @@ def list_available_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks

+    def add_spk_info(self, spk_id, spk_info):
+        self.frontend.add_spk_info(spk_id, spk_info)
+
     def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)
@@ -88,6 +91,22 @@ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=F
                 yield model_output
                 start_time = time.time()

+    def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        """Run zero-shot inference with a pre-registered speaker."""
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
+            start_time = time.time()
+            last_time = start_time
+            chunk_index = 0
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
+                    chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
+                yield model_output
+                last_time = time.time()
+                chunk_index += 1
+
     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
@@ -126,7 +145,7 @@ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -145,7 +164,14 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        if use_vllm:
+            try:
+                self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
+            except Exception as e:
+                logging.warning(f'failed to initialize vllm inference.\n{e}')
+                raise
+        else:
+            self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
@@ -171,3 +197,14 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+    def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
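
Usage note: a minimal sketch of the new entry points added above, assuming a local CosyVoice2-0.5B download and a speaker previously registered under the id 'my_spk' (the path, speaker id, and text below are placeholders):

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2

# use_vllm=True swaps the LLM for VllmCosyVoice2Model; any init failure is
# logged and re-raised rather than silently falling back to the torch LLM
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', use_vllm=True)

# synthesize with a pre-registered speaker, no prompt audio needed at call time
for idx, out in enumerate(cosyvoice.inference_zero_shot_by_spk_id('收到好友从远方寄来的生日礼物。', 'my_spk', stream=True)):
    torchaudio.save('zero_shot_{}.wav'.format(idx), out['tts_speech'], cosyvoice.sample_rate)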
84 changes: 81 additions & 3 deletions cosyvoice/cli/frontend.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Generator
+from typing import Generator, Optional
 import json
 import onnxruntime
 import torch
@@ -24,6 +24,8 @@
 import os
 import re
 import inflect
+from pydantic import BaseModel, ConfigDict
+
 try:
     import ttsfrd
     use_ttsfrd = True
@@ -36,6 +38,18 @@
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


+class SpeakerInfo(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    name: Optional[str] = None
+    spk_id: str
+    prompt_text: str
+    prompt_text_token: torch.Tensor
+    speech_feat: torch.Tensor
+    speech_token: torch.Tensor
+    embedding: torch.Tensor
+
+
 class CosyVoiceFrontEnd:

     def __init__(self,
@@ -55,8 +69,9 @@ def __init__(self,
         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                 "CPUExecutionProvider"])
+        self.spk2info_path = spk2info
         if os.path.exists(spk2info):
-            self.spk2info = torch.load(spk2info, map_location=self.device)
+            self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False)
         else:
             self.spk2info = {}
         self.allowed_special = allowed_special
@@ -68,7 +83,8 @@ def __init__(self,
                 'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyinvg')
         else:
-            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+            # self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False)
             self.en_tn_model = EnNormalizer()
             self.inflect_parser = inflect.engine()

@@ -138,11 +154,15 @@ def text_normalize(self, text, split=True, text_frontend=True):
             text = text.replace(" - ", ",")
             text = remove_bracket(text)
             text = re.sub(r'[,,、]+$', '。', text)
+            if not split:
+                return text
             texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                          token_min_n=60, merge_len=20, comma_split=False))
         else:
             text = self.en_tn_model.normalize(text)
             text = spell_out_number(text, self.inflect_parser)
+            if not split:
+                return text
             texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                          token_min_n=60, merge_len=20, comma_split=False))
         texts = [i for i in texts if not is_only_punctuation(i)]
@@ -151,6 +171,7 @@
     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         embedding = self.spk2info[spk_id]['embedding']
+        assert embedding is not None
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

@@ -209,3 +230,60 @@ def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
                        'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                        'flow_embedding': embedding}
         return model_input
+
+    def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate: int = 24000, name: Optional[str] = None):
+        assert isinstance(spk_id, str)
+        assert spk_id not in self.spk2info, "spk_id already exists"
+        prompt_text_token, _ = self._extract_text_token(prompt_text)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, _ = self._extract_speech_feat(prompt_speech_resample)
+        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        if resample_rate == 24000:
+            # cosyvoice2: enforce a 2:1 ratio between speech_feat frames and speech tokens
+            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+            speech_feat = speech_feat[:, :2 * token_len]
+            speech_token = speech_token[:, :token_len]
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        spk_info = SpeakerInfo(
+            name=name,
+            spk_id=spk_id,
+            prompt_text=prompt_text,
+            prompt_text_token=prompt_text_token,
+            speech_feat=speech_feat,
+            speech_token=speech_token,
+            embedding=embedding,
+        )
+        self.add_spk_info(spk_id, spk_info)
+
+    def add_spk_info(self, spk_id: str, spk_info: dict | SpeakerInfo):
+        if isinstance(spk_info, BaseModel):
+            spk_info = spk_info.model_dump()
+        self.spk2info[spk_id] = spk_info
+        if self.spk2info_path:
+            torch.save(self.spk2info, self.spk2info_path)
+
+    def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id):
+        assert spk_id in self.spk2info
+        tts_text_token, _ = self._extract_text_token(tts_text)
+        prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>')
+        model_input = {'text': tts_text_token,
+                       'prompt_text': prompt_text_token,
+                       'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
+                       'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
+                       'llm_embedding': self.spk2info[spk_id]['embedding'],
+                       'flow_embedding': self.spk2info[spk_id]['embedding'],
+                       }
+        return model_input
+
+    def frontend_zero_shot_by_spk_id(self, tts_text, spk_id):
+        assert spk_id in self.spk2info
+        tts_text_token, _ = self._extract_text_token(tts_text)
+        model_input = {'text': tts_text_token,
+                       'prompt_text': self.spk2info[spk_id]['prompt_text_token'],
+                       'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
+                       'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
+                       'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
+                       'llm_embedding': self.spk2info[spk_id]['embedding'],
+                       'flow_embedding': self.spk2info[spk_id]['embedding']
+                       }
+        return model_input
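
Usage note: a sketch of the speaker-registration flow these helpers enable, assuming cosyvoice is a CosyVoice2 instance and prompt.wav is a 16 kHz mono prompt recording (the file name, speaker id, and texts are placeholders):

from cosyvoice.utils.file_utils import load_wav

prompt_speech_16k = load_wav('prompt.wav', 16000)
# extract prompt text tokens, speech tokens, speech features and the speaker
# embedding once; add_spk_info then persists the spk2info dict to spk2info_path
cosyvoice.frontend.generate_spk_info('my_spk', '希望你以后能够做的比我还好呦。', prompt_speech_16k, 24000, name='demo speaker')

# later sessions reload spk2info from disk, so only the id is needed
model_input = cosyvoice.frontend.frontend_zero_shot_by_spk_id('今天天气怎么样?', 'my_spk')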
23 changes: 23 additions & 0 deletions cosyvoice/cli/model.py
@@ -409,3 +409,26 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             torch.cuda.empty_cache()
+
+
+class VllmCosyVoice2Model(CosyVoice2Model):
+    def __init__(self,
+                 model_dir: str,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool):
+        try:
+            from cosyvoice.llm.llm_vllm import VllmQwen2LM
+        except ImportError as e:
+            raise ImportError(f'vllm backend unavailable: {e}') from e
+        llm = VllmQwen2LM(model_dir)
+        super().__init__(llm, flow, hift, fp16)
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in
+                           torch.load(hift_model, weights_only=True, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
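
Design note: VllmCosyVoice2Model overrides load() to restore only the flow and hift weights; the llm_model argument is ignored, since VllmQwen2LM is constructed directly from model_dir and the vLLM engine is expected to load the LLM weights itself. This is what lets CosyVoice2.__init__ call self.model.load() with the same three checkpoint paths in both the vllm and non-vllm branches.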