diff --git a/doc_generator/main.py b/doc_generator/main.py
index 5770a2d..bae7744 100644
--- a/doc_generator/main.py
+++ b/doc_generator/main.py
@@ -53,6 +53,14 @@ def main():  # pragma: no cover
     #           LLMModels.GOOGLE_GEMMA_7B_INSTRUCT.value])
     model_name = prompt("Which model?\n")
     match model_name:
+        case LLMModels.LLAMA2_7B_CHAT_GPTQ.value:
+            model = LLMModels.LLAMA2_7B_CHAT_GPTQ
+        case LLMModels.LLAMA2_13B_CHAT_GPTQ.value:
+            model = LLMModels.LLAMA2_13B_CHAT_GPTQ
+        case LLMModels.CODELLAMA_7B_INSTRUCT_GPTQ.value:
+            model = LLMModels.CODELLAMA_7B_INSTRUCT_GPTQ
+        case LLMModels.CODELLAMA_13B_INSTRUCT_GPTQ.value:
+            model = LLMModels.CODELLAMA_13B_INSTRUCT_GPTQ
         case LLMModels.LLAMA2_13B_CHAT_HF.value:
             model = LLMModels.LLAMA2_13B_CHAT_HF
         case LLMModels.CODELLAMA_7B_INSTRUCT_HF.value:
diff --git a/doc_generator/utils/llm_utils.py b/doc_generator/utils/llm_utils.py
index 612c454..fb8c7f3 100644
--- a/doc_generator/utils/llm_utils.py
+++ b/doc_generator/utils/llm_utils.py
@@ -154,21 +154,21 @@ def get_tokenizer(model_name: str):
         failed=0,
         total=0,
     ),
-    # LLMModels.LLAMA2_7B_CHAT_GPTQ: LLMModelDetails(
-    #     name=LLMModels.LLAMA2_7B_CHAT_GPTQ,
-    #     input_cost_per_1k_tokens=0,
-    #     output_cost_per_1k_tokens=0,
-    #     max_length=4096,
-    #     llm=get_llama_chat_model(
-    #         LLMModels.LLAMA2_7B_CHAT_GPTQ.value,
-    #         model_kwargs={"temperature": 0}
-    #     ),
-    #     input_tokens=0,
-    #     output_tokens=0,
-    #     succeeded=0,
-    #     failed=0,
-    #     total=0,
-    # ),
+    LLMModels.LLAMA2_7B_CHAT_GPTQ: LLMModelDetails(
+        name=LLMModels.LLAMA2_7B_CHAT_GPTQ,
+        input_cost_per_1k_tokens=0,
+        output_cost_per_1k_tokens=0,
+        max_length=4096,
+        llm=get_llama_chat_model(
+            LLMModels.LLAMA2_7B_CHAT_GPTQ.value,
+            model_kwargs={"temperature": 0}
+        ),
+        input_tokens=0,
+        output_tokens=0,
+        succeeded=0,
+        failed=0,
+        total=0,
+    ),
     # LLMModels.LLAMA2_13B_CHAT_GPTQ: LLMModelDetails(
     #     name=LLMModels.LLAMA2_13B_CHAT_GPTQ,
     #     input_cost_per_1k_tokens=0,
@@ -274,21 +274,21 @@ def get_tokenizer(model_name: str):
     #     failed=0,
     #     total=0,
     # ),
-    LLMModels.GOOGLE_GEMMA_2B_INSTRUCT: LLMModelDetails(
-        name=LLMModels.GOOGLE_GEMMA_2B_INSTRUCT,
-        input_cost_per_1k_tokens=0,
-        output_cost_per_1k_tokens=0,
-        max_length=8192,
-        llm=get_gemma_chat_model(
-            LLMModels.GOOGLE_GEMMA_2B_INSTRUCT.value,
-            model_kwargs={"temperature": 0}
-        ),
-        input_tokens=0,
-        output_tokens=0,
-        succeeded=0,
-        failed=0,
-        total=0,
-    ),
+    # LLMModels.GOOGLE_GEMMA_2B_INSTRUCT: LLMModelDetails(
+    #     name=LLMModels.GOOGLE_GEMMA_2B_INSTRUCT,
+    #     input_cost_per_1k_tokens=0,
+    #     output_cost_per_1k_tokens=0,
+    #     max_length=8192,
+    #     llm=get_gemma_chat_model(
+    #         LLMModels.GOOGLE_GEMMA_2B_INSTRUCT.value,
+    #         model_kwargs={"temperature": 0}
+    #     ),
+    #     input_tokens=0,
+    #     output_tokens=0,
+    #     succeeded=0,
+    #     failed=0,
+    #     total=0,
+    # ),
     # LLMModels.GOOGLE_GEMMA_7B_INSTRUCT: LLMModelDetails(
     #     name=LLMModels.GOOGLE_GEMMA_7B_INSTRUCT,
     #     input_cost_per_1k_tokens=0,
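
A note on the `main.py` hunk: every new `case` arm maps an enum member's `.value` back to the member itself, which is exactly what Python's built-in `Enum` value lookup (`LLMModels(model_name)`) already does. A minimal sketch of that equivalence, assuming `LLMModels` is a standard `enum.Enum` whose member values are the strings typed at the prompt (the value strings and the default below are illustrative, not copied from the repo):

```python
from enum import Enum


class LLMModels(Enum):
    # Illustrative subset; the real enum lives in doc_generator.
    LLAMA2_7B_CHAT_GPTQ = "TheBloke/Llama-2-7B-Chat-GPTQ"
    LLAMA2_13B_CHAT_GPTQ = "TheBloke/Llama-2-13B-Chat-GPTQ"


def resolve_model(model_name: str) -> LLMModels:
    """Map a user-supplied value string back to its enum member."""
    try:
        # Enum(value) is the built-in inverse of member.value,
        # covering every member with one expression instead of
        # one match-case arm per model.
        return LLMModels(model_name)
    except ValueError:
        # Unknown input: fall back to a default, mirroring a
        # catch-all `case _:` arm.
        return LLMModels.LLAMA2_7B_CHAT_GPTQ
```

Whether to collapse the `match` block this way is a style call; the explicit arms added here do keep the set of supported models greppable in `main.py`.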