
Commit 0e060fc

Refactor HuggingFace model initialization to include base model name and update tokenizer logic (#190)
1 parent 64c8bc9 commit 0e060fc

File tree

2 files changed (+8, -4 lines)


src/agentlab/llm/chat_api.py

Lines changed: 2 additions & 1 deletion
@@ -406,13 +406,14 @@ class HuggingFaceURLChatModel(HFBaseChatModel):
     def __init__(
         self,
         model_name: str,
+        base_model_name: str,
         model_url: str,
         token: Optional[str] = None,
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
     ):
-        super().__init__(model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
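
After this change, callers must pass a base_model_name in addition to model_name. A minimal usage sketch; the model names and endpoint URL below are placeholders, not values taken from the repository:

    chat_model = HuggingFaceURLChatModel(
        model_name="my-org/finetuned-agent-model",        # checkpoint served at the endpoint (hypothetical)
        base_model_name="meta-llama/Llama-2-7b-chat-hf",  # base model whose tokenizer is loaded (hypothetical)
        model_url="http://localhost:8080",                # inference endpoint URL (hypothetical)
        temperature=0.1,
        max_new_tokens=512,
        n_retry_server=4,
    )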

src/agentlab/llm/huggingface_utils.py

Lines changed: 6 additions & 3 deletions
@@ -40,14 +40,17 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )

-    def __init__(self, model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server):
         super().__init__()
         self.n_retry_server = n_retry_server

-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if base_model_name is None:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
         if isinstance(self.tokenizer, GPT2TokenizerFast):
             logging.warning(
-                f"No chat template is defined for {model_name}. Resolving to the hard-coded templates."
+                f"No chat template is defined for {base_model_name}. Resolving to the hard-coded templates."
             )
             self.tokenizer = None
         self.prompt_template = get_prompt_template(model_name)
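
The tokenizer is now resolved from base_model_name when one is provided, falling back to model_name otherwise. A standalone sketch of that selection logic, using the same transformers imports as the file; the helper name load_tokenizer is illustrative only and does not exist in the repository:

    import logging
    from transformers import AutoTokenizer, GPT2TokenizerFast

    def load_tokenizer(model_name, base_model_name=None):
        # Prefer the base model's tokenizer when one is given (useful for
        # fine-tuned checkpoints that reuse the base model's chat template).
        name = model_name if base_model_name is None else base_model_name
        tokenizer = AutoTokenizer.from_pretrained(name)
        if isinstance(tokenizer, GPT2TokenizerFast):
            # Mirrors the original check: a GPT2TokenizerFast instance is
            # treated as a sign that no chat template is defined, so the
            # caller falls back to the hard-coded prompt templates.
            logging.warning(f"No chat template is defined for {name}.")
            return None
        return tokenizer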
