Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/microsoft/promptbench into …
Browse files Browse the repository at this point in the history
…main
  • Loading branch information
jindongwang committed Dec 22, 2023
2 parents b6210b3 + 479cf87 commit d40f7d3
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions promptbench/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def __init__(self, model_name, max_new_tokens, temperature=0, system_prompt=None
super(OpenAIModel, self).__init__(model_name, max_new_tokens, temperature)
self.openai_key = openai_key
self.sleep_time = sleep_time
self.system_prompt = system_prompt

if self.temperature > 0:
raise Warning("Temperature is not 0, so that the results may not be reproducable!")
Expand All @@ -283,8 +284,11 @@ def predict(self, input_text, **kwargs):

from openai import OpenAI
client = OpenAI(api_key=self.openai_key)

system_messages = {'role': "system", 'content': "You are a helpful assistant."}

if self.system_prompt is None:
system_messages = {'role': "system", 'content': "You are a helpful assistant."}
else:
system_messages = {'role': "system", 'content': self.system_prompt}

if isinstance(input_text, list):
messages = input_text
Expand Down Expand Up @@ -346,15 +350,15 @@ class PaLMModel(LMMBaseModel):
model_dir : str, optional
The directory containing the model files (default is None).
"""
def __init__(self, model, max_new_tokens, temperature=0, system_prompt=None, palm_key=None, sleep_time=3):
def __init__(self, model, max_new_tokens, temperature=0, system_prompt=None, api_key=None, sleep_time=3):
super(PaLMModel, self).__init__(model, max_new_tokens, temperature)
self.palm_key = palm_key
self.api_key = api_key
self.sleep_time = sleep_time

def predict(self, input_text, **kwargs):
import google.generativeai as palm

palm.configure(api_key=self.palm_key)
palm.configure(api_key=self.api_key)
models = [m for m in palm.list_models() if 'generateText' in m.supported_generation_methods]
model = models[0].name

Expand All @@ -377,7 +381,6 @@ def predict(self, input_text, **kwargs):

return result


class GeminiModel(LMMBaseModel):
"""
Language model class for interfacing with Google's Gemini models.
Expand Down

0 comments on commit d40f7d3

Please sign in to comment.