Skip to content

Commit 1d4df44

Browse files
committed
Revert "update"
This reverts commit 0dae8ed.
1 parent 0dae8ed commit 1d4df44

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

docs/guides/generation_details.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ config.claude_api = "claude api"
7070

7171
config.openai_key = "openai api"
7272

73+
config.palm_api = "palm api"
74+
7375
config.ernie_client_id = "ernie client id"
7476

7577
config.ernie_client_secret = "ernie client secret"

trustllm_pkg/setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
'python-dotenv',
2626
'urllib3',
2727
'anthropic',
28+
'google.generativeai',
2829
'google-api-python-client',
30+
'google.ai.generativelanguage',
2931
'replicate',
3032
'zhipuai>=2.0.1'
3133
],

trustllm_pkg/trustllm/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
deepinfra_api = None
1010
ernie_api = None
1111
claude_api = None
12+
palm_api = None
1213
replicate_api = None
1314
zhipu_api = None
1415

@@ -37,17 +38,19 @@
3738
zhipu_model = ["glm-4", "glm-3-turbo"]
3839
claude_model = ["claude-2", "claude-instant-1"]
3940
openai_model = ["chatgpt", "gpt-4"]
41+
google_model = ["bison-001", "gemini"]
4042
wenxin_model = ["ernie"]
4143
replicate_model=["vicuna-7b","vicuna-13b","vicuna-33b","chatglm3-6b","llama3-70b","llama3-8b"]
4244

43-
online_model = deepinfra_model + zhipu_model + claude_model + openai_model + wenxin_model+replicate_model
45+
online_model = deepinfra_model + zhipu_model + claude_model + openai_model + google_model + wenxin_model+replicate_model
4446

4547
model_info = {
4648
"online_model": online_model,
4749
"zhipu_model": zhipu_model,
4850
"deepinfra_model": deepinfra_model,
4951
'claude_model': claude_model,
5052
'openai_model': openai_model,
53+
'google_model': google_model,
5154
'wenxin_model': wenxin_model,
5255
'replicate_model':replicate_model,
5356
"model_mapping": {

trustllm_pkg/trustllm/utils/generation_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os, json
22
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
3-
3+
import google.generativeai as genai
4+
from google.generativeai.types import safety_types
45
from fastchat.model import load_model, get_conversation_template
56
from openai import OpenAI,AzureOpenAI
67
from tenacity import retry, wait_random_exponential, stop_after_attempt
@@ -16,6 +17,16 @@
1617
model_mapping = model_info['model_mapping']
1718
rev_model_mapping = {value: key for key, value in model_mapping.items()}
1819

20+
# Relax every PaLM safety filter so harmful-content prompts are not blocked
# (needed when probing models with red-teaming / trustworthiness prompts).
_BLOCK_NONE = safety_types.HarmBlockThreshold.BLOCK_NONE
safety_setting = [
    {"category": category, "threshold": _BLOCK_NONE}
    for category in (
        safety_types.HarmCategory.HARM_CATEGORY_DEROGATORY,
        safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
        safety_types.HarmCategory.HARM_CATEGORY_SEXUAL,
        safety_types.HarmCategory.HARM_CATEGORY_TOXICITY,
        safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
        safety_types.HarmCategory.HARM_CATEGORY_DANGEROUS,
    )
]
29+
1930
# Retrieve model information
2031
def get_models():
2132
return model_mapping, online_model_list
@@ -87,7 +98,31 @@ def claude_api(string, model, temperature):
8798
return completion.completion
8899

89100

101+
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def gemini_api(string, temperature):
    """Query the Gemini Pro model and return the generated text.

    Args:
        string: The prompt to send to the model.
        temperature: Sampling temperature for generation.

    Returns:
        The text of the model's response (consistent with the other
        ``*_api`` helpers, e.g. ``palm_api`` returning ``completion.result``).
    """
    # NOTE(review): this commit's config.py only adds `palm_api`; confirm that
    # `gemini_api` is actually defined on trustllm.config before relying on it.
    genai.configure(api_key=trustllm.config.gemini_api)
    model = genai.GenerativeModel('gemini-pro')
    # `generate_content` does not accept a bare `temperature` kwarg; it must
    # be supplied inside a GenerationConfig.
    response = model.generate_content(
        string,
        generation_config=genai.types.GenerationConfig(temperature=temperature),
        safety_settings=safety_setting,
    )
    # Return the response text, not the raw response object, so downstream
    # code receives a string like every other online-model helper.
    return response.text
107+
108+
90109

110+
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def palm_api(string, model, temperature):
    """Query a Google PaLM text model and return the generated text.

    Args:
        string: The prompt to send to the model.
        model: Short model name; currently only 'bison-001' is supported.
        temperature: Sampling temperature for generation.

    Returns:
        The generated text (``completion.result``); may be ``None`` if the
        request was blocked — TODO confirm against the PaLM API docs.
    """
    genai.configure(api_key=trustllm.config.palm_api)

    # Renamed from `model_mapping` so it does not shadow the module-level
    # `model_mapping` built from model_info.
    palm_model_names = {
        'bison-001': 'models/text-bison-001',
    }
    completion = genai.generate_text(
        model=palm_model_names[model],  # e.g. 'models/text-bison-001'
        prompt=string,
        temperature=temperature,
        # The maximum length of the response.
        max_output_tokens=4000,
        safety_settings=safety_setting,
    )
    return completion.result
91126

92127

93128
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
@@ -113,6 +148,11 @@ def zhipu_api(string, model, temperature):
113148
def gen_online(model_name, prompt, temperature, replicate=False, deepinfra=False):
114149
if model_name in model_info['wenxin_model']:
115150
res = get_ernie_res(prompt, temperature=temperature)
151+
elif model_name in model_info['google_model']:
152+
if model_name == 'bison-001':
153+
res = palm_api(prompt, model=model_name, temperature=temperature)
154+
elif model_name == 'gemini-pro':
155+
res = gemini_api(prompt, temperature=temperature)
116156
elif model_name in model_info['openai_model']:
117157
res = get_res_openai(prompt, model=model_name, temperature=temperature)
118158
elif model_name in model_info['deepinfra_model']:

0 commit comments

Comments
 (0)