Merge pull request #224 from Dartvauder/dev
Dev
Dartvauder authored Sep 28, 2024
2 parents 73a75d9 + 86e3c1a commit 5ab80fc
Showing 3 changed files with 113 additions and 37 deletions.
114 changes: 77 additions & 37 deletions LaunchFile/app.py
@@ -28,7 +28,6 @@
 import librosa
 import librosa.display
 import base64
-import io
 import gc
 import cv2
 import subprocess
@@ -665,7 +664,7 @@ def generate_magicprompt(prompt, max_new_tokens):
     return enhanced_prompt


-def load_model(model_name, model_type, n_ctx=None):
+def load_model(model_name, model_type, n_ctx, n_batch):
     if model_name:
         model_path = f"inputs/text/llm_models/{model_name}"
         if model_type == "transformers":
@@ -675,7 +674,7 @@ def load_model(model_name, model_type, n_ctx=None):
             model = AutoModelForCausalLM().AutoModelForCausalLM.from_pretrained(
                 model_path,
                 device_map=device,
-                load_in_4bit=True,
+                load_in_8bit=True,
                 torch_dtype=torch.float16,
                 trust_remote_code=True
             )
@@ -687,8 +686,7 @@ def load_model(model_name, model_type, n_ctx=None):
         elif model_type == "llama":
             try:
                 device = "cuda" if torch.cuda.is_available() else "cpu"
-                model = Llama().Llama(model_path, n_gpu_layers=-1 if device == "cuda" else 0)
-                model.n_ctx = n_ctx
+                model = Llama().Llama(model_path, n_gpu_layers=-1 if device == "cuda" else 0, n_ctx=n_ctx, n_batch=n_batch)
                 tokenizer = None
                 return tokenizer, model, None
             except (ValueError, RuntimeError):
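For context on the change above: llama-cpp-python allocates the context when the model is constructed, so the old pattern of assigning model.n_ctx after construction only set a Python attribute and never resized the already-allocated context; both values now go to the constructor. A minimal sketch of the new call path (model file name and sizes are illustrative, not from this commit):

    # Sketch, assuming llama-cpp-python's Llama API; values are illustrative.
    from llama_cpp import Llama

    llm = Llama(
        "inputs/text/llm_models/example-model.gguf",  # hypothetical model file
        n_gpu_layers=-1,  # offload all layers when CUDA is available, else 0
        n_ctx=4096,       # context window, fixed at load time
        n_batch=512,      # prompt-evaluation batch size
    )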
@@ -841,9 +839,9 @@ def get_languages():
     }


-def generate_text_and_speech(input_text, system_prompt, input_audio, llm_model_name, llm_lora_model_name, enable_web_search, enable_libretranslate, target_lang, enable_openparse, pdf_file, enable_multimodal, input_image, enable_tts,
-                             llm_settings_html, llm_model_type, max_length, max_tokens,
-                             temperature, top_p, top_k, chat_history_format, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, output_format):
+def generate_text_and_speech(input_text, system_prompt, input_audio, llm_model_type, llm_model_name, llm_lora_model_name, enable_web_search, enable_libretranslate, target_lang, enable_openparse, pdf_file, enable_multimodal, input_image, enable_tts,
+                             llm_settings_html, max_new_tokens, max_length, min_length, n_ctx, n_batch, temperature, top_p, min_p, typical_p, top_k,
+                             do_sample, early_stopping, stopping, repetition_penalty, frequency_penalty, presence_penalty, length_penalty, no_repeat_ngram_size, num_beams, num_return_sequences, chat_history_format, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, tts_repetition_penalty, tts_length_penalty, output_format):
     global chat_history, chat_dir, tts_model, whisper_model

     if 'chat_history' not in globals() or chat_history is None:
@@ -895,7 +893,7 @@ def generate_text_and_speech(input_text, system_prompt, input_audio, llm_model_n
                 local_dir=moondream2_path,
                 filename="*text-model*",
                 chat_handler=chat_handler,
-                n_ctx=2048,
+                n_ctx=n_ctx,
             )
             try:
                 if input_image:
@@ -979,7 +977,11 @@ def image_to_base64_data_uri(image_path):
             flush()

         else:
-            tokenizer, llm_model, error_message = load_model(llm_model_name, llm_model_type)
+            if llm_model_type == "llama":
+                tokenizer, llm_model, error_message = load_model(llm_model_name, llm_model_type, n_ctx, n_batch)
+            else:
+                tokenizer, llm_model, error_message = load_model(llm_model_name, llm_model_type, n_ctx=None, n_batch=None)
+
             if llm_lora_model_name:
                 tokenizer, llm_model, error_message = load_lora_model(llm_model_name, llm_lora_model_name, llm_model_type)
             if error_message:
@@ -1024,6 +1026,9 @@ def image_to_base64_data_uri(image_path):

                 tokenizer.pad_token = tokenizer.eos_token
                 tokenizer.padding_side = "left"
+                stop_words = stopping.split(',') if stopping.strip() else None
+                stop_ids = [tokenizer.encode(word.strip(), add_special_tokens=False)[0] for word in
+                            stop_words] if stop_words else None

                 full_prompt = f"{system_prompt}\n\n{openparse_context}{web_context}{context}Human: {prompt}\nAssistant:"
                 inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True).to(device)
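One subtlety in the added stop handling: tokenizer.encode(...)[0] keeps only the first token id of each comma-separated stop word, so a stop word that tokenizes to several ids will end generation as soon as its first token appears. A small check that makes this visible (the printed ids are hypothetical; real values depend on the tokenizer in use):

    # Sketch: inspect how stop words tokenize before relying on first-id stopping.
    stopping = "Human:, Observation"
    for word in stopping.split(','):
        ids = tokenizer.encode(word.strip(), add_special_tokens=False)
        print(repr(word.strip()), ids)  # e.g. 'Human:' -> [20490, 25]; only 20490 is kept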
@@ -1037,13 +1042,20 @@ def image_to_base64_data_uri(image_path):
                 with torch.no_grad():
                     output = llm_model.generate(
                         **inputs,
-                        max_new_tokens=max_length,
-                        do_sample=True,
+                        max_new_tokens=max_new_tokens,
+                        max_length=max_length,
+                        min_length=min_length,
+                        do_sample=do_sample,
+                        early_stopping=early_stopping,
                         top_p=top_p,
                         top_k=top_k,
                         temperature=temperature,
-                        repetition_penalty=1.1,
-                        no_repeat_ngram_size=2,
+                        repetition_penalty=repetition_penalty,
+                        length_penalty=length_penalty,
+                        no_repeat_ngram_size=no_repeat_ngram_size,
+                        num_beams=num_beams,
+                        num_return_sequences=num_return_sequences,
+                        eos_token_id=stop_ids if stop_ids else tokenizer.eos_token_id
                     )

                 next_token = output[0][inputs['input_ids'].shape[1]:]
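Two generate() behaviors matter for the block above (as of recent transformers releases; exact behavior can vary by version): eos_token_id accepts either a single id or a list, with generation stopping at any listed id; and when max_new_tokens and max_length are both supplied, max_new_tokens takes precedence and transformers emits a warning. A hedged sketch with illustrative values:

    # Sketch, assuming Hugging Face transformers' generate(); ids are illustrative.
    output = llm_model.generate(
        **inputs,
        max_new_tokens=256,    # takes precedence if max_length is also given
        eos_token_id=[2, 13],  # generation stops at ANY of these token ids
    )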
@@ -1064,18 +1076,24 @@ def image_to_base64_data_uri(image_path):
                 if not chat_history or chat_history[-1][1] is not None:
                     chat_history.append([prompt, ""])

+                stop_sequences = [seq.strip() for seq in stopping.split(',')] if stopping.strip() else None
+
                 full_prompt = f"{system_prompt}\n\n{openparse_context}{web_context}{context}Human: {prompt}\nAssistant:"

                 for token in llm_model(
                     full_prompt,
-                    max_tokens=max_tokens,
-                    stop=["Human:", "\n"],
-                    stream=True,
-                    echo=False,
-                    temperature=temperature,
+                    max_tokens=max_new_tokens,
                     top_p=top_p,
+                    min_p=min_p,
+                    typical_p=typical_p,
                     top_k=top_k,
-                    repeat_penalty=1.1,
+                    temperature=temperature,
+                    repeat_penalty=repetition_penalty,
+                    frequency_penalty=frequency_penalty,
+                    presence_penalty=presence_penalty,
+                    stop=stop_sequences,
+                    stream=True,
+                    echo=False
                 ):

                     text += token['choices'][0]['text']
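The streaming call above yields OpenAI-style completion chunks, each carrying only the newly generated fragment under choices[0]['text']. A self-contained consumption sketch (the model file is hypothetical):

    # Sketch: consuming llama-cpp-python's streaming generator.
    from llama_cpp import Llama

    llm = Llama("inputs/text/llm_models/example-model.gguf", n_ctx=2048)  # hypothetical file
    text = ""
    for chunk in llm("Human: Hello\nAssistant:", max_tokens=64, stream=True, echo=False):
        piece = chunk['choices'][0]['text']  # incremental fragment, not the cumulative text
        text += piece
        print(piece, end="", flush=True)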
@@ -1116,11 +1134,9 @@ def image_to_base64_data_uri(image_path):
         with open(chat_history_path, "w", encoding="utf-8") as f:
             json.dump(chat_history_json, f, ensure_ascii=False, indent=4)
         if enable_tts and text:
-            repetition_penalty = 2.0
-            length_penalty = 1.0
             wav = tts_model.tts(text=text, speaker_wav=f"inputs/audio/voices/{speaker_wav}", language=language,
                                 temperature=tts_temperature, top_p=tts_top_p, top_k=tts_top_k, speed=tts_speed,
-                                repetition_penalty=repetition_penalty, length_penalty=length_penalty)
+                                repetition_penalty=tts_repetition_penalty, length_penalty=tts_length_penalty)
             now = datetime.now()
             audio_filename = f"TTS_{now.strftime('%Y%m%d_%H%M%S')}.{output_format}"
             audio_path = os.path.join(chat_dir, 'audio', audio_filename)
@@ -1149,7 +1165,7 @@ def image_to_base64_data_uri(image_path):
         yield chat_history, audio_path, chat_dir, None


-def generate_tts_stt(text, audio, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, tts_output_format, stt_output_format):
+def generate_tts_stt(text, audio, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, tts_repetition_penalty, tts_length_penalty, tts_output_format, stt_output_format):
     global tts_model, whisper_model

     tts_output = None
@@ -1168,7 +1184,8 @@ def generate_tts_stt(text, audio, tts_settings_html, speaker_wav, language, tts_

     try:
         wav = tts_model.tts(text=text, speaker_wav=f"inputs/audio/voices/{speaker_wav}", language=language,
-                            temperature=tts_temperature, top_p=tts_top_p, top_k=tts_top_k, speed=tts_speed)
+                            temperature=tts_temperature, top_p=tts_top_p, top_k=tts_top_k, speed=tts_speed,
+                            repetition_penalty=tts_repetition_penalty, length_penalty=tts_length_penalty )
     except Exception as e:
         return None, str(e)

@@ -8567,10 +8584,11 @@ def reload_interface():
         gr.Textbox(label=_("Enter your request", lang)),
         gr.Textbox(label=_("Enter your system prompt", lang)),
         gr.Audio(type="filepath", label=_("Record your request (optional)", lang)),
-        gr.Dropdown(choices=llm_models_list, label=_("Select LLM model", lang), value=None)
+        gr.Radio(choices=["transformers", "llama"], label=_("Select model type", lang), value="transformers"),
+        gr.Dropdown(choices=llm_models_list, label=_("Select LLM model", lang), value=None),
+        gr.Dropdown(choices=llm_lora_models_list, label=_("Select LoRA model (optional)", lang), value=None),
     ],
     additional_inputs=[
-        gr.Dropdown(choices=llm_lora_models_list, label=_("Select LoRA model (optional)", lang), value=None),
         gr.Checkbox(label=_("Enable WebSearch", lang), value=False),
         gr.Checkbox(label=_("Enable LibreTranslate", lang), value=False),
         gr.Dropdown(choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh", "ja", "hi"],
@@ -8581,12 +8599,26 @@ def reload_interface():
         gr.Image(label=_("Upload your image (for Multimodal)", lang), type="filepath"),
         gr.Checkbox(label=_("Enable TTS", lang), value=False),
         gr.HTML(_("<h3>LLM Settings</h3>", lang)),
-        gr.Radio(choices=["transformers", "llama"], label=_("Select model type", lang), value="transformers"),
-        gr.Slider(minimum=256, maximum=4096, value=512, step=1, label=_("Max length (for transformers type models)", lang)),
-        gr.Slider(minimum=256, maximum=4096, value=512, step=1, label=_("Max tokens (for llama type models)", lang)),
+        gr.Slider(minimum=256, maximum=32768, value=512, step=1, label=_("Max tokens", lang)),
+        gr.Slider(minimum=256, maximum=32768, value=512, step=1, label=_("Max length", lang)),
+        gr.Slider(minimum=256, maximum=32768, value=512, step=1, label=_("Min length", lang)),
+        gr.Slider(minimum=256, maximum=32768, value=512, step=1, label=_("Context size (N_CTX) for llama type models", lang)),
+        gr.Slider(minimum=256, maximum=32768, value=512, step=1, label=_("Context batch (N_BATCH) for llama type models", lang)),
         gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label=_("Temperature", lang)),
         gr.Slider(minimum=0.01, maximum=1.0, value=0.9, step=0.01, label=_("Top P", lang)),
-        gr.Slider(minimum=1, maximum=100, value=20, step=1, label=_("Top K", lang)),
+        gr.Slider(minimum=0.01, maximum=1.0, value=0.05, step=0.01, label=_("Min P", lang)),
+        gr.Slider(minimum=0.01, maximum=1.0, value=1.0, step=0.01, label=_("Typical P", lang)),
+        gr.Slider(minimum=1, maximum=200, value=40, step=1, label=_("Top K", lang)),
+        gr.Checkbox(label=_("Enable Do Sample", lang), value=False),
+        gr.Checkbox(label=_("Enable Early Stopping", lang), value=False),
+        gr.Textbox(label=_("Stop sequences (optional)", lang), value=""),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label=_("Repetition penalty", lang)),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.0, step=0.1, label=_("Frequency penalty", lang)),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.0, step=0.1, label=_("Presence penalty", lang)),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label=_("Length penalty", lang)),
+        gr.Slider(minimum=1, maximum=10, value=2, step=1, label=_("No repeat ngram size", lang)),
+        gr.Slider(minimum=0, maximum=10, value=0, step=1, label=_("Num beams", lang)),
+        gr.Slider(minimum=1, maximum=10, value=1, step=1, label=_("Num return sequences", lang)),
         gr.Radio(choices=["txt", "json"], label=_("Select chat history format", lang), value="txt", interactive=True),
         gr.HTML(_("<h3>TTS Settings</h3>", lang)),
         gr.Dropdown(choices=speaker_wavs_list, label=_("Select voice", lang), interactive=True),
@@ -8595,6 +8627,8 @@ def reload_interface():
         gr.Slider(minimum=0.01, maximum=1.0, value=0.9, step=0.01, label=_("TTS Top P", lang), interactive=True),
         gr.Slider(minimum=1, maximum=100, value=20, step=1, label=_("TTS Top K", lang), interactive=True),
         gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label=_("TTS Speed", lang), interactive=True),
+        gr.Slider(minimum=0.1, maximum=2.0, value=2.0, step=0.1, label=_("TTS Repetition penalty", lang), interactive=True),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label=_("TTS Length penalty", lang), interactive=True),
         gr.Radio(choices=["wav", "mp3", "ogg"], label=_("Select output format", lang), value="wav", interactive=True)
     ],
     additional_inputs_accordion=gr.Accordion(label=_("LLM and TTS Settings", lang), open=False),
@@ -8624,6 +8658,9 @@ def reload_interface():
         gr.Slider(minimum=0.01, maximum=1.0, value=0.9, step=0.01, label=_("TTS Top P", lang), interactive=True),
         gr.Slider(minimum=1, maximum=100, value=20, step=1, label=_("TTS Top K", lang), interactive=True),
         gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label=_("TTS Speed", lang), interactive=True),
+        gr.Slider(minimum=0.1, maximum=2.0, value=2.0, step=0.1, label=_("TTS Repetition penalty", lang),
+                  interactive=True),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label=_("TTS Length penalty", lang), interactive=True),
         gr.Radio(choices=["wav", "mp3", "ogg"], label=_("Select TTS output format", lang), value="wav", interactive=True),
         gr.Dropdown(choices=["txt", "json"], label=_("Select STT output format", lang), value="txt", interactive=True)
     ],
@@ -11184,25 +11221,28 @@ def reload_interface():
     folder_button = gr.Button(_("Outputs", lang))

     dropdowns_to_update = [
-        chat_interface.input_components[3],
         chat_interface.input_components[4],
-        chat_interface.input_components[22],
+        chat_interface.input_components[5],
+        chat_interface.input_components[38],
         tts_stt_interface.input_components[3],
         txt2img_interface.input_components[2],
         txt2img_interface.input_components[3],
-        txt2img_interface.input_components[4],
-        txt2img_interface.input_components[6],
+        txt2img_interface.input_components[5],
+        txt2img_interface.input_components[7],
         img2img_interface.input_components[5],
         img2img_interface.input_components[6],
-        img2img_interface.input_components[7],
-        img2img_interface.input_components[9],
+        img2img_interface.input_components[8],
+        img2img_interface.input_components[10],
         controlnet_interface.input_components[4],
         inpaint_interface.input_components[6],
         inpaint_interface.input_components[7],
         outpaint_interface.input_components[4],
         gligen_interface.input_components[5],
         animatediff_interface.input_components[5],
-        sd3_txt2img_interface.input_components[3],
+        sd3_txt2img_interface.input_components[6],
         sd3_img2img_interface.input_components[5],
         flux_img2img_interface.input_components[3],
         t2i_ip_adapter_interface.input_components[4],
         ip_adapter_faceid_interface.input_components[5],
         flux_txt2img_interface.input_components[2],
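The index shifts in dropdowns_to_update follow from Gradio's positional binding: inputs and additional_inputs are concatenated, in order, into input_components and passed positionally to the wrapped function, so moving the LoRA dropdown into inputs and inserting the model-type radio and the new sliders renumbers every later component. A toy illustration (names and values are illustrative, not from this commit):

    # Sketch, assuming Gradio's Interface API: components bind by position.
    import gradio as gr

    def fn(text, model_type, temperature):
        return f"{model_type}: {text} (T={temperature})"

    demo = gr.Interface(
        fn=fn,
        inputs=[gr.Textbox(), gr.Radio(["transformers", "llama"])],  # indices 0 and 1
        additional_inputs=[gr.Slider(0.1, 2.0, value=0.7)],          # index 2
        outputs="text",
    )
    # demo.input_components[1] is the Radio; inserting a component before it
    # would shift every later index, which is why the hardcoded indices above changed.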
18 changes: 18 additions & 0 deletions translations/ru.json
@@ -6,6 +6,24 @@
   "Select LoRA model (optional)": "Выберите модель LoRA (необязательно)",
   "LLM and TTS Settings": "LLM и TTS настройки",
   "TTS and STT Settings": "TTS и STT настройки",
+  "Max tokens": "Максимум токенов",
+  "Min length": "Минимальная длина",
+  "Context size (N_CTX) for llama type models": "Размер контекста (N_CTX) для моделей типа llama",
+  "Context batch (N_BATCH) for llama type models": "Размер пакета контекста (N_BATCH) для моделей типа llama",
+  "Min P": "Минимальное P",
+  "Typical P": "Типичное P",
+  "Stop sequences (optional)": "Последовательности остановки (необязательно)",
+  "TTS Repetition penalty": "TTS Штраф за повторение",
+  "TTS Length penalty": "TTS Штраф за длину",
+  "Enable Do Sample": "Включить выборку",
+  "Enable Early Stopping": "Включить раннюю остановку",
+  "Repetition penalty": "Штраф за повторение",
+  "Frequency penalty": "Штраф за частоту",
+  "Presence penalty": "Штраф за присутствие",
+  "Length penalty": "Штраф за длину",
+  "No repeat ngram size": "Размер n-граммы без повторений",
+  "Num beams": "Количество лучей",
+  "Num return sequences": "Количество возвращаемых последовательностей",
   "SeamlessM4Tv2 Settings": "SeamlessM4Tv2 настройки",
   "Select TTS output format": "Выберите формат TTS",
   "Select STT output format": "Выберите формат STT",
18 changes: 18 additions & 0 deletions translations/zh.json
@@ -6,6 +6,24 @@
   "Select LoRA model (optional)": "选择LoRA模型(可选)",
   "LLM and TTS Settings": "LLM和TTS设置",
   "TTS and STT Settings": "TTS和STT设置",
+  "Max tokens": "最大令牌数",
+  "Min length": "最小长度",
+  "Context size (N_CTX) for llama type models": "llama类型模型的上下文大小 (N_CTX)",
+  "Context batch (N_BATCH) for llama type models": "llama类型模型的上下文批次 (N_BATCH)",
+  "Min P": "最小P值",
+  "Typical P": "典型P值",
+  "Stop sequences (optional)": "停止序列(可选)",
+  "TTS Repetition penalty": "TTS重复惩罚",
+  "TTS Length penalty": "TTS长度惩罚",
+  "Enable Do Sample": "启用采样",
+  "Enable Early Stopping": "启用提前停止",
+  "Repetition penalty": "重复惩罚",
+  "Frequency penalty": "频率惩罚",
+  "Presence penalty": "存在惩罚",
+  "Length penalty": "长度惩罚",
+  "No repeat ngram size": "不重复的n元组大小",
+  "Num beams": "束搜索数量",
+  "Num return sequences": "返回序列数量",
   "SeamlessM4Tv2 Settings": "SeamlessM4Tv2设置",
   "Select TTS output format": "选择TTS输出格式",
   "Select STT output format": "选择STT输出格式",
