1
1
import os , json
2
2
from anthropic import Anthropic , HUMAN_PROMPT , AI_PROMPT
3
-
3
+ import google .generativeai as genai
4
+ from google .generativeai .types import safety_types
4
5
from fastchat .model import load_model , get_conversation_template
5
6
from openai import OpenAI ,AzureOpenAI
6
7
from tenacity import retry , wait_random_exponential , stop_after_attempt
16
17
model_mapping = model_info ['model_mapping' ]
17
18
rev_model_mapping = {value : key for key , value in model_mapping .items ()}
18
19
20
# Safety settings that disable API-side content blocking for every harm
# category, so that evaluation prompts are returned unfiltered.
# NOTE(review): these are the PaLM-era HarmCategory values from
# safety_types; confirm the Gemini endpoints accept the same set.
_UNBLOCKED_CATEGORIES = (
    safety_types.HarmCategory.HARM_CATEGORY_DEROGATORY,
    safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
    safety_types.HarmCategory.HARM_CATEGORY_SEXUAL,
    safety_types.HarmCategory.HARM_CATEGORY_TOXICITY,
    safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
    safety_types.HarmCategory.HARM_CATEGORY_DANGEROUS,
)

safety_setting = [
    {"category": category, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}
    for category in _UNBLOCKED_CATEGORIES
]
19
30
# Retrieve model information
def get_models():
    """Return the module-level model tables.

    Returns:
        tuple: ``(model_mapping, online_model_list)`` — the display-name to
        internal-name mapping and the list of models served via remote APIs.
        Both are defined at module import time from ``model_info``.
    """
    return model_mapping, online_model_list
@@ -87,7 +98,31 @@ def claude_api(string, model, temperature):
87
98
return completion .completion
88
99
89
100
101
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def gemini_api(string, temperature):
    """Send one prompt to the Gemini Pro model and return the raw response.

    Args:
        string: Prompt text to send to the model.
        temperature: Sampling temperature for generation.

    Returns:
        The raw response object from ``generate_content`` (callers extract
        the text themselves).

    Retries up to 6 times with exponential backoff on transient API errors.
    """
    genai.configure(api_key=trustllm.config.gemini_api)
    model = genai.GenerativeModel('gemini-pro')
    # BUG FIX: ``generate_content`` does not accept a bare ``temperature``
    # keyword — sampling parameters must be wrapped in a GenerationConfig,
    # otherwise every call raises TypeError.
    response = model.generate_content(
        string,
        generation_config=genai.types.GenerationConfig(temperature=temperature),
        safety_settings=safety_setting,
    )
    return response
107
+
108
+
90
109
110
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def palm_api(string, model, temperature):
    """Send one prompt to a PaLM text model and return the generated text.

    Args:
        string: Prompt text to send to the model.
        model: Short model key; currently only ``'bison-001'`` is supported.
        temperature: Sampling temperature for generation.

    Returns:
        ``completion.result`` — the generated text string (may be ``None``
        if the API returned no candidates).

    Raises:
        KeyError: If ``model`` is not a known short model key.

    Retries up to 6 times with exponential backoff on transient API errors.
    """
    genai.configure(api_key=trustllm.config.palm_api)

    # Renamed from ``model_mapping`` to avoid shadowing the module-level
    # ``model_mapping`` dict built from ``model_info`` at import time.
    palm_model_paths = {
        'bison-001': 'models/text-bison-001',
    }
    completion = genai.generate_text(
        model=palm_model_paths[model],  # e.g. models/text-bison-001
        prompt=string,
        temperature=temperature,
        # The maximum length of the response
        max_output_tokens=4000,
        safety_settings=safety_setting,
    )
    return completion.result
91
126
92
127
93
128
@retry (wait = wait_random_exponential (min = 1 , max = 10 ), stop = stop_after_attempt (6 ))
@@ -113,6 +148,11 @@ def zhipu_api(string, model, temperature):
113
148
def gen_online (model_name , prompt , temperature , replicate = False , deepinfra = False ):
114
149
if model_name in model_info ['wenxin_model' ]:
115
150
res = get_ernie_res (prompt , temperature = temperature )
151
+ elif model_name in model_info ['google_model' ]:
152
+ if model_name == 'bison-001' :
153
+ res = palm_api (prompt , model = model_name , temperature = temperature )
154
+ elif model_name == 'gemini-pro' :
155
+ res = gemini_api (prompt , temperature = temperature )
116
156
elif model_name in model_info ['openai_model' ]:
117
157
res = get_res_openai (prompt , model = model_name , temperature = temperature )
118
158
elif model_name in model_info ['deepinfra_model' ]:
0 commit comments