
Commit

with text-bison
ishaan-jaff committed Aug 14, 2023
1 parent 15944eb commit 1b4aadb
Showing 4 changed files with 85 additions and 29 deletions.
10 changes: 8 additions & 2 deletions litellm/__init__.py
@@ -100,12 +100,18 @@ def identify(event_details):
'meta-llama/llama-2-70b-chat'
]

-vertex_models = [
+vertex_chat_models = [
"chat-bison",
"chat-bison@001"
]

-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_models
+
+vertex_text_models = [
+"text-bison",
+"text-bison@001"
+]
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_chat_models + vertex_text_models

####### EMBEDDING MODELS ###################
open_ai_embedding_models = [
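For orientation: splitting the old vertex_models list into vertex_chat_models and vertex_text_models is what lets the rest of the library route a model name to the right Vertex AI class. A minimal sketch of that membership check, written as hypothetical standalone code that mirrors the elif branches added to litellm/main.py below:

    import litellm

    # Membership in these lists decides which Vertex AI class handles a call
    # (this mirrors the routing added to litellm/main.py in this commit).
    for model in ["chat-bison", "text-bison@001", "gpt-3.5-turbo"]:
        if model in litellm.vertex_chat_models:
            print(f"{model}: routed to vertexai ChatModel")
        elif model in litellm.vertex_text_models:
            print(f"{model}: routed to vertexai TextGenerationModel")
        else:
            print(f"{model}: handled by another provider branch")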
31 changes: 28 additions & 3 deletions litellm/main.py
@@ -47,7 +47,10 @@ def completion(
temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
# Optional liteLLM function params
-*, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None
+*, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None,
+# model specific optional params
+# used by text-bison only
+top_k=40,
):
try:
global new_response
@@ -61,7 +64,7 @@ def completion(
temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
# params to identify the model
-model=model, custom_llm_provider=custom_llm_provider
+model=model, custom_llm_provider=custom_llm_provider, top_k=top_k,
)
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
@@ -366,7 +369,7 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
-elif model in litellm.vertex_models:
+elif model in litellm.vertex_chat_models:
# import vertexai / if it fails then pip install vertexai
install_and_import("vertexai")
import vertexai
@@ -387,6 +390,28 @@ def completion(
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)

## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time()
model_response["model"] = model
elif model in litellm.vertex_text_models:
# import vertexai / if it fails then pip install vertexai
install_and_import("vertexai")
import vertexai
from vertexai.language_models import TextGenerationModel

vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
# vertexai does not use an API key; it looks for credentials.json in the environment

prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
vertex_model = TextGenerationModel.from_pretrained(model)
completion_response = vertex_model.predict(prompt, **optional_params)

## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)

## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time()
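Taken together, the new branch lets a text-bison request go through the same completion() entry point as every other provider. A minimal usage sketch, assuming Google Cloud credentials are available in the environment and reusing the project/location values from the test file below:

    import litellm
    from litellm import completion

    # vertexai reads Google Cloud credentials from the environment
    # (e.g. GOOGLE_APPLICATION_CREDENTIALS); no API key is passed to litellm.
    litellm.vertex_project = "hardy-device-386718"   # example values from the tests
    litellm.vertex_location = "us-central1"

    messages = [{"role": "user", "content": "what's the weather in SF"}]

    # top_k is the text-bison-specific optional param introduced in this commit
    response = completion(model="text-bison@001", messages=messages,
                          temperature=0.8, top_p=0.4, top_k=30)
    print(response["choices"][0]["message"]["content"])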
63 changes: 40 additions & 23 deletions litellm/tests/test_vertex.py
@@ -1,32 +1,49 @@
# import sys, os
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os
# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion


# litellm.vertex_project = "hardy-device-386718"
# litellm.vertex_location = "us-central1"
litellm.vertex_project = "hardy-device-386718"
litellm.vertex_location = "us-central1"
litellm.set_verbose = True

# user_message = "what's the weather in SF "
# messages = [{ "content": user_message,"role": "user"}]
user_message = "what's the weather in SF "
messages = [{ "content": user_message,"role": "user"}]

# def logger_fn(user_model_dict):
# print(f"user_model_dict: {user_model_dict}")
# chat-bison
# response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
# print(response)

# text-bison

# # chat_model = ChatModel.from_pretrained("chat-bison@001")
# # parameters = {
# # "temperature": 0.2,
# # "max_output_tokens": 256,
# # "top_p": 0.8,
# # "top_k": 40
# # }
# response = completion(model="text-bison@001", messages=messages)
# print(response)

# response = completion(model="text-bison@001", messages=messages, temperature=0.1, logger_fn=logger_fn)
# print(response)

# response = completion(model="text-bison@001", messages=messages, temperature=0.4, top_p=0.1, logger_fn=logger_fn)
# print(response)

# response = completion(model="text-bison@001", messages=messages, temperature=0.8, top_p=0.4, top_k=30, logger_fn=logger_fn)
# print(response)

# chat_model = ChatModel.from_pretrained("chat-bison@001")
# parameters = {
# "temperature": 0.2,
# "max_output_tokens": 256,
# "top_p": 0.8,
# "top_k": 40
# }

# # chat = chat_model.start_chat()
# # response = chat.send_message("who are u? write a sentence", **parameters)
# # print(f"Response from Model: {response.text}")
# chat = chat_model.start_chat()
# response = chat.send_message("who are u? write a sentence", **parameters)
# print(f"Response from Model: {response.text}")
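If these smoke tests were ever run under pytest rather than as commented-out top-level calls, one of them could be wrapped roughly like this (a sketch only, not part of the commit; it assumes the imports and messages defined at the top of this file and live Vertex AI credentials):

    def test_vertex_text_completion():
        # exercises the new text-bison branch end to end
        try:
            response = completion(model="text-bison@001", messages=messages,
                                  temperature=0.2, top_p=0.8, top_k=40)
            print(response)
        except Exception as e:
            pytest.fail(f"error occurred: {e}")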
10 changes: 9 additions & 1 deletion litellm/utils.py
@@ -246,7 +246,8 @@ def get_optional_params(
user = "",
deployment_id = None,
model = None,
-custom_llm_provider = ""
+custom_llm_provider = "",
+top_k = 40,
):
optional_params = {}
if model in litellm.anthropic_models:
@@ -293,6 +294,13 @@ def get_optional_params(
optional_params["top_p"] = top_p
if max_tokens != float('inf'):
optional_params["max_output_tokens"] = max_tokens
elif model in litellm.vertex_text_models:
# required params for all Vertex AI text calls
# e.g. temperature=0.2, top_p=0.1, top_k=20
# always set temperature, top_p, and top_k, otherwise text-bison fails
optional_params["temperature"] = temperature
optional_params["top_p"] = top_p
optional_params["top_k"] = top_k

else:# assume passing in params for openai/azure openai
if functions != []:
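For illustration, with the branch above a text-bison call ends up sending exactly temperature, top_p, and top_k to Vertex AI. A rough sketch of the expected behaviour, assuming the remaining keyword arguments of get_optional_params keep their defaults (not code from the commit):

    from litellm.utils import get_optional_params

    # hypothetical call mirroring how completion() forwards its arguments
    params = get_optional_params(model="text-bison@001",
                                 temperature=0.7, top_p=0.9, top_k=20)
    print(params)  # expected: {'temperature': 0.7, 'top_p': 0.9, 'top_k': 20}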
