@@ -45,16 +45,19 @@ def __init__(
         self,
         model_name: str = "gpt-3.5-turbo",
         max_tokens: int | None = None,
+        api_base: str | None = None,
     ):
         """Initialize APIAgent with model_name and max_tokens.

         Args:
             model_name: Name of the model to use (by default, gpt-3.5-turbo).
             max_tokens: The maximum number of tokens to generate. Defaults to the max
                 value for the model if available through litellm.
+            api_base: Custom endpoint for Hugging Face's inference API.
         """
         self.model_name = model_name
         self.max_tokens = max_tokens
+        self.api_base = api_base
         if max_tokens is None:
             try:
                 self.max_tokens = litellm.utils.get_max_tokens(model_name)
@@ -99,6 +102,7 @@ def generate_one_completion(
             messages=[
                 {"role": "user", "content": f"{prompt}"},
             ],
+            api_base=self.api_base,
             temperature=temperature,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
@@ -144,6 +148,7 @@ async def _throttled_completion_acreate(
            return await acompletion(
                model=model,
                messages=messages,
+                api_base=self.api_base,
                temperature=temperature,
                max_tokens=max_tokens,
                n=n,
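For context, a minimal usage sketch of the new `api_base` parameter (not part of the diff): the import path, model id, and endpoint URL below are placeholders, and litellm's `huggingface/` provider prefix is assumed for routing requests to a custom endpoint.

```python
# Sketch only: the APIAgent import path, model id, and endpoint URL are assumptions.
from prompt2model.utils.api_tools import APIAgent  # hypothetical import path

agent = APIAgent(
    model_name="huggingface/bigcode/starcoder",  # litellm routes huggingface/* models via api_base
    max_tokens=512,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",  # custom HF inference endpoint
)

# generate_one_completion now forwards api_base to litellm.completion(...);
# the sampling arguments (temperature, penalties) are assumed to have defaults.
response = agent.generate_one_completion("Write a one-line docstring for a sort function.")
```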