Minor fixes and addition of quantization config.
souradipp76 committed May 22, 2024
1 parent 96b0008 commit a77a5ba
Showing 4 changed files with 43 additions and 15 deletions.
13 changes: 9 additions & 4 deletions doc_generator/index/process_repository.py
@@ -20,7 +20,7 @@
github_file_url,
github_folder_url,
)
from doc_generator.utils.llm_utils import models
from doc_generator.utils.llm_utils import models, get_tokenizer
from doc_generator.utils.traverse_file_system import traverse_file_system

from .prompts import (
@@ -123,9 +123,14 @@ def process_file(process_file_params: ProcessFileParams):
return
assert model is not None

encoding = tiktoken.encoding_for_model(model.name)
summary_length = len(encoding.encode(summary_prompt))
question_length = len(encoding.encode(questions_prompt))
if "gpt" in model.name.lower():
encoding = tiktoken.encoding_for_model(model.name)
summary_length = len(encoding.encode(summary_prompt))
question_length = len(encoding.encode(questions_prompt))
else:
encoding = get_tokenizer(model.name)
summary_length = len(encoding.tokenize(summary_prompt))
question_length = len(encoding.tokenize(questions_prompt))

if not dry_run:
responses = [call_llm(prompt, model.llm) for prompt in prompts]
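For context, the new branch keeps tiktoken for OpenAI GPT models and falls back to a Hugging Face tokenizer for open-weight models, whose token counts tiktoken cannot produce. A minimal standalone sketch of the same idea (the function name and model ids below are illustrative, not from the repository):

```python
import tiktoken
from transformers import AutoTokenizer


def count_prompt_tokens(model_name: str, prompt: str) -> int:
    """Count prompt tokens with the tokenizer that matches the model family."""
    if "gpt" in model_name.lower():
        # OpenAI models: tiktoken maps the model name to its encoding.
        encoding = tiktoken.encoding_for_model(model_name)
        return len(encoding.encode(prompt))
    # Open-weight models (Llama, Gemma, ...): use the Hugging Face tokenizer,
    # as get_tokenizer() in llm_utils.py now does.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    return len(tokenizer.tokenize(prompt))


# Illustrative calls (model ids are placeholders):
# count_prompt_tokens("gpt-3.5-turbo", "Summarize this file.")
# count_prompt_tokens("meta-llama/Llama-2-7b-chat-hf", "Summarize this file.")
```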
4 changes: 3 additions & 1 deletion doc_generator/index/select_model.py
@@ -25,10 +25,12 @@ def select_model(prompts: List[str],
get_max_prompt_length(prompts, model_enum):
return models[model_enum]
return None
else:
elif priority == Priority.PERFORMANCE:
for model_enum in [LLMModels.GPT4, LLMModels.GPT432k, LLMModels.GPT3]:
if model_enum in llms:
if models[model_enum].max_length > \
get_max_prompt_length(prompts, model_enum):
return models[model_enum]
return None
else:
return models[llms[0]]
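The change narrows the old catch-all `else`: only `Priority.PERFORMANCE` now gets the GPT-4-first search, and any other priority simply returns the first model the caller listed. A hedged restatement of the resulting control flow, using the module's own names (`Priority`, `LLMModels`, `models`, `get_max_prompt_length`) as they appear in the diff; the COST ordering shown here is an assumption, since that branch lies outside the hunk:

```python
def select_model(prompts, llms, priority):
    """Pick a model whose context window fits the prompts, by cost or performance."""
    if priority == Priority.COST:
        order = [LLMModels.GPT3, LLMModels.GPT4, LLMModels.GPT432k]  # assumed cheapest-first
    elif priority == Priority.PERFORMANCE:
        order = [LLMModels.GPT4, LLMModels.GPT432k, LLMModels.GPT3]  # strongest-first, per the diff
    else:
        # Any other priority: fall back to the first model the caller offered.
        return models[llms[0]]

    for model_enum in order:
        if model_enum in llms and \
                models[model_enum].max_length > get_max_prompt_length(prompts, model_enum):
            return models[model_enum]
    return None  # nothing fits within its max context length
```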
38 changes: 29 additions & 9 deletions doc_generator/utils/llm_utils.py
@@ -7,23 +7,30 @@
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import (AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
pipeline)

from doc_generator.types import LLMModelDetails, LLMModels


def get_gemma_chat_model(model_name: str, model_kwargs):
"""Get GEMMA Chat Model"""
# config = AutoConfig.from_pretrained(model_name)
# config.quantization_config["use_exllama"] = False
# config.quantization_config["exllama_config"] = {"version" : 2}
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map="auto",
# config=config,
quantization_config=bnb_config
)
return HuggingFacePipeline(
pipeline=pipeline(
@@ -40,16 +47,23 @@ def get_gemma_chat_model(model_name: str, model_kwargs):

def get_llama_chat_model(model_name: str, model_kwargs):
"""Get LLAMA2 Chat Model"""
# config = AutoConfig.from_pretrained(model_name)
# config.quantization_config["use_exllama"] = False
# config.quantization_config["exllama_config"] = {"version" : 2}

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
# use_exllama=False,
# exllama_config={"version": 2}
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map="auto",
# config=config,
quantization_config=bnb_config
)
return HuggingFacePipeline(
pipeline=pipeline(
@@ -84,6 +98,12 @@ def get_openai_api_key():
return ""


def get_tokenizer(model_name: str):
"""Get Tokenizer"""
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
return tokenizer


models = {
LLMModels.GPT3: LLMModelDetails(
name=LLMModels.GPT3,
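The commit's main change replaces the commented-out exllama settings with a 4-bit NF4 `BitsAndBytesConfig` in both chat-model loaders. A minimal sketch of loading a causal LM the same way, assuming bitsandbytes, accelerate, and a CUDA-capable GPU are available (the model id is a placeholder):

```python
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, pipeline)

model_name = "google/gemma-7b-it"  # placeholder repo id

# 4-bit NF4 weights with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",              # let accelerate place layers on available devices
    quantization_config=bnb_config,
)

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generator("Explain NF4 quantization in one sentence.",
                max_new_tokens=64)[0]["generated_text"])
```

NF4 with double quantization stores the weights in roughly a quarter of their fp16 footprint, which is the practical point of the change for running 7B-class chat models on a single GPU.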
3 changes: 2 additions & 1 deletion requirements.txt
@@ -4,7 +4,8 @@ markdown2
langchain_openai
langchain_experimental
hnswlib
accelerate
accelerate
bitsandbytes
optimum
auto-gptq
python-magic
