diff --git a/promptbench/models/models.py b/promptbench/models/models.py
index 4a23344..adc81ed 100644
--- a/promptbench/models/models.py
+++ b/promptbench/models/models.py
@@ -268,6 +268,7 @@ def __init__(self, model_name, max_new_tokens, temperature=0, system_prompt=None
         super(OpenAIModel, self).__init__(model_name, max_new_tokens, temperature)
         self.openai_key = openai_key
         self.sleep_time = sleep_time
+        self.system_prompt = system_prompt
 
         if self.temperature > 0:
             raise Warning("Temperature is not 0, so that the results may not be reproducable!")
@@ -283,8 +284,11 @@ def predict(self, input_text, **kwargs):
         from openai import OpenAI
         client = OpenAI(api_key=self.openai_key)
-
-        system_messages = {'role': "system", 'content': "You are a helpful assistant."}
+
+        if self.system_prompt is None:
+            system_messages = {'role': "system", 'content': "You are a helpful assistant."}
+        else:
+            system_messages = {'role': "system", 'content': self.system_prompt}
 
         if isinstance(input_text, list):
             messages = input_text
 
@@ -346,15 +350,15 @@ class PaLMModel(LMMBaseModel):
     model_dir : str, optional
         The directory containing the model files (default is None).
     """
-    def __init__(self, model, max_new_tokens, temperature=0, system_prompt=None, palm_key=None, sleep_time=3):
+    def __init__(self, model, max_new_tokens, temperature=0, system_prompt=None, api_key=None, sleep_time=3):
         super(PaLMModel, self).__init__(model, max_new_tokens, temperature)
-        self.palm_key = palm_key
+        self.api_key = api_key
         self.sleep_time = sleep_time
 
     def predict(self, input_text, **kwargs):
         import google.generativeai as palm
-        palm.configure(api_key=self.palm_key)
+        palm.configure(api_key=self.api_key)
         models = [m for m in palm.list_models() if 'generateText' in m.supported_generation_methods]
         model = models[0].name
 
@@ -377,7 +381,6 @@ def predict(self, input_text, **kwargs):
         return result
 
-
 
 class GeminiModel(LMMBaseModel):
     """
     Language model class for interfacing with Google's Gemini models.
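
For reviewers, a minimal usage sketch of the new `system_prompt` parameter (not part of the diff). It assumes `OpenAIModel` also accepts `openai_key` as a keyword, as the assignments in the first hunk suggest, and that the class is importable from `promptbench.models.models`; the model name and prompt strings below are placeholders.

    # Sketch only: exercises the system_prompt added to OpenAIModel.__init__.
    # Import path mirrors the file touched by the diff; treat it as an
    # assumption if the package re-exports these classes elsewhere.
    from promptbench.models.models import OpenAIModel

    # With the patch, a caller-supplied system prompt replaces the hardcoded
    # "You are a helpful assistant." default inside predict().
    model = OpenAIModel(
        model_name="gpt-3.5-turbo",       # placeholder model name
        max_new_tokens=100,
        temperature=0,                    # 0 avoids the reproducibility Warning
        system_prompt="You are a strict grader. Answer with a single letter.",
        openai_key="sk-...",              # placeholder key
    )
    print(model.predict("Is the sky blue? (A) yes (B) no"))

    # Omitting system_prompt (i.e. leaving it None) keeps the previous
    # default behavior of predict().
    default_model = OpenAIModel(
        model_name="gpt-3.5-turbo",
        max_new_tokens=100,
        temperature=0,
        openai_key="sk-...",
    )

Note that the rename of `palm_key` to `api_key` in `PaLMModel.__init__` is a breaking change for any caller that passes `palm_key=` by keyword.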