From 1b4aadbb253043378fd58205dc66f4fcc9339976 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 14 Aug 2023 10:33:59 -0700
Subject: [PATCH] with text-bison

---
 litellm/__init__.py          | 10 ++++--
 litellm/main.py              | 31 ++++++++++++++++--
 litellm/tests/test_vertex.py | 63 +++++++++++++++++++++++-------------
 litellm/utils.py             | 10 +++++-
 4 files changed, 85 insertions(+), 29 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 4d118bd9aad9..017e7f3515ee 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -100,12 +100,18 @@ def identify(event_details):
     'meta-llama/llama-2-70b-chat'
 ]
 
-vertex_models = [
+vertex_chat_models = [
     "chat-bison",
     "chat-bison@001"
 ]
 
-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_models
+
+vertex_text_models = [
+    "text-bison",
+    "text-bison@001"
+]
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_chat_models + vertex_text_models
 
 ####### EMBEDDING MODELS ###################
 open_ai_embedding_models = [
diff --git a/litellm/main.py b/litellm/main.py
index c047ee0c064a..d8a447995a02 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -47,7 +47,10 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
-    *, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None
+    *, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None,
+    # model specific optional params
+    # used by text-bison only
+    top_k=40,
   ):
   try:
     global new_response
@@ -61,7 +64,7 @@ def completion(
       temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
       presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
       # params to identify the model
-      model=model, custom_llm_provider=custom_llm_provider
+      model=model, custom_llm_provider=custom_llm_provider, top_k=top_k,
     )
     # For logging - save the values of the litellm-specific params passed in
     litellm_params = get_litellm_params(
@@ -366,7 +369,7 @@ def completion(
         "total_tokens": prompt_tokens + completion_tokens
       }
       response = model_response
-    elif model in litellm.vertex_models:
+    elif model in litellm.vertex_chat_models:
       # import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
       install_and_import("vertexai")
       import vertexai
@@ -387,6 +390,28 @@ def completion(
 
       ## LOGGING
       logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+      ## RESPONSE OBJECT
+      model_response["choices"][0]["message"]["content"] = completion_response
+      model_response["created"] = time.time()
+      model_response["model"] = model
+    elif model in litellm.vertex_text_models:
+      # import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
+      install_and_import("vertexai")
+      import vertexai
+      from vertexai.language_models import TextGenerationModel
+
+      vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
+      # vertexai does not use an API key, it looks for credentials.json in the environment
+
+      prompt = " ".join([message["content"] for message in messages])
+      ## LOGGING
+      logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+      vertex_model = TextGenerationModel.from_pretrained(model)
+      completion_response = vertex_model.predict(prompt, **optional_params)
+
+      ## LOGGING
+      logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
       ## RESPONSE OBJECT
       model_response["choices"][0]["message"]["content"] = completion_response
       model_response["created"] = time.time()
diff --git a/litellm/tests/test_vertex.py b/litellm/tests/test_vertex.py
index 4214c300bccc..742c3423ca25 100644
--- a/litellm/tests/test_vertex.py
+++ b/litellm/tests/test_vertex.py
@@ -1,32 +1,49 @@
-# import sys, os
-# import traceback
-# from dotenv import load_dotenv
-# load_dotenv()
-# import os
-# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
-# import pytest
-# import litellm
-# from litellm import embedding, completion
+import sys, os
+import traceback
+from dotenv import load_dotenv
+load_dotenv()
+import os
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import embedding, completion
 
-# litellm.vertex_project = "hardy-device-386718"
-# litellm.vertex_location = "us-central1"
+litellm.vertex_project = "hardy-device-386718"
+litellm.vertex_location = "us-central1"
+litellm.set_verbose = True
 
-# user_message = "what's the weather in SF "
-# messages = [{ "content": user_message,"role": "user"}]
+user_message = "what's the weather in SF "
+messages = [{ "content": user_message,"role": "user"}]
 
+# def logger_fn(user_model_dict):
+#     print(f"user_model_dict: {user_model_dict}")
 
+# chat-bison
 # response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
 # print(response)
 
+# text-bison
 
-# # chat_model = ChatModel.from_pretrained("chat-bison@001")
-# # parameters = {
-# #     "temperature": 0.2,
-# #     "max_output_tokens": 256,
-# #     "top_p": 0.8,
-# #     "top_k": 40
-# # }
+# response = completion(model="text-bison@001", messages=messages)
+# print(response)
+
+# response = completion(model="text-bison@001", messages=messages, temperature=0.1, logger_fn=logger_fn)
+# print(response)
+
+# response = completion(model="text-bison@001", messages=messages, temperature=0.4, top_p=0.1, logger_fn=logger_fn)
+# print(response)
+
+# response = completion(model="text-bison@001", messages=messages, temperature=0.8, top_p=0.4, top_k=30, logger_fn=logger_fn)
+# print(response)
+
+# chat_model = ChatModel.from_pretrained("chat-bison@001")
+# parameters = {
+#     "temperature": 0.2,
+#     "max_output_tokens": 256,
+#     "top_p": 0.8,
+#     "top_k": 40
+# }
 
-# # chat = chat_model.start_chat()
-# # response = chat.send_message("who are u? write a sentence", **parameters)
-# # print(f"Response from Model: {response.text}")
\ No newline at end of file
+# chat = chat_model.start_chat()
+# response = chat.send_message("who are u? write a sentence", **parameters)
+# print(f"Response from Model: {response.text}")
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 1e2420170909..6fc85aaa1031 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -246,7 +246,8 @@ def get_optional_params(
     user = "",
     deployment_id = None,
     model = None,
-    custom_llm_provider = ""
+    custom_llm_provider = "",
+    top_k = 40,
 ):
   optional_params = {}
   if model in litellm.anthropic_models:
@@ -293,6 +294,13 @@ def get_optional_params(
       optional_params["top_p"] = top_p
     if max_tokens != float('inf'):
       optional_params["max_output_tokens"] = max_tokens
+  elif model in litellm.vertex_text_models:
+    # required params for all text vertex calls
+    # temperature=0.2, top_p=0.1, top_k=20
+    # always set temperature, top_p, top_k else, text bison fails
+    optional_params["temperature"] = temperature
+    optional_params["top_p"] = top_p
+    optional_params["top_k"] = top_k
   else:# assume passing in params for openai/azure openai
     if functions != []:
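
A minimal usage sketch of the text-bison path this patch adds, assuming Vertex AI credentials (credentials.json) are available in the environment; the project value below is a placeholder, not a real GCP project:

import litellm
from litellm import completion

# placeholder GCP settings; completion() passes these to vertexai.init()
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

messages = [{"role": "user", "content": "what's the weather in SF"}]

# for vertex text models, get_optional_params always forwards
# temperature, top_p, and top_k on to TextGenerationModel.predict()
response = completion(model="text-bison@001", messages=messages,
                      temperature=0.2, top_p=0.8, top_k=40)
print(response["choices"][0]["message"]["content"])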