From 7fca0f137f218c8e5976127b7a15262c4aced501 Mon Sep 17 00:00:00 2001
From: Peilin Yu
Date: Sun, 21 Jul 2024 16:11:40 -0400
Subject: [PATCH] trigger do_sample automatically based on temperature for
 huggingface transformers models

---
 alfred/fm/huggingface.py      |  7 +++++--
 alfred/fm/utils.py            | 10 ++++++++--
 docs/alfred/fm/huggingface.md | 10 +++++-----
 docs/alfred/fm/utils.md       |  6 ++++--
 4 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/alfred/fm/huggingface.py b/alfred/fm/huggingface.py
index e4bd5b5..ed14bce 100644
--- a/alfred/fm/huggingface.py
+++ b/alfred/fm/huggingface.py
@@ -130,7 +130,9 @@ def __init__(
 
         if torch.cuda.is_available():
             n_gpus = torch.cuda.device_count()
-            free_in_GB = sum([int(mem / 1024**3) for mem in torch.cuda.mem_get_info()])
+            free_in_GB = sum(
+                [int(mem / 1024**3) for mem in torch.cuda.mem_get_info()]
+            )
             logger.log(
                 logging.INFO, f"Found {n_gpus} GPUs with {free_in_GB}GB free GPU memory"
             )
@@ -398,9 +400,10 @@ def _generate_batch(
             outputs = self.model.generate(
                 inputs.input_ids.to(self.model.device),
                 max_new_tokens=max_new_tokens,
-                temperature=temprature,
+                temperature=temprature if temprature != 0 else None,
                 repetition_penalty=repetition_penalty,
                 return_dict_in_generate=True,
+                do_sample=temprature != 0,
             )
         else:
             outputs = [
diff --git a/alfred/fm/utils.py b/alfred/fm/utils.py
index 4056a64..cfa0987 100644
--- a/alfred/fm/utils.py
+++ b/alfred/fm/utils.py
@@ -526,13 +526,15 @@ def _process_batch(batch):
     return batches
 
 
-def static_batch(queries: Query, batch_size: int = 1024) -> List[List[Query]]:
+def static_batch(
+    queries: Union[Query, str], batch_size: int = 512
+) -> List[List[Query]]:
     """
     Static Batching Utility
     Batch queries into fixed size batches
 
     :param queries: A list of queries to be batched
-    :type queries: List[Query]
+    :type queries: Union[Query, str]
     :param batch_sz: The batch size
     :type batch_sz: int
     :return: A list of batches
@@ -548,6 +550,10 @@ def static_batch(queries: Query, batch_size: int = 1024) -> List[List[Query]]:
             _q = query.load()[0]
         elif isinstance(query, RankedQuery):
             _q = query.prompt
+        elif isinstance(query, str):
+            _q = query
+        else:
+            print(f"Unknown query type {type(query)}")
         batch.append(_q)
     if len(batch) > 0:
         batches.append(batch)
diff --git a/docs/alfred/fm/huggingface.md b/docs/alfred/fm/huggingface.md
index 4a5fecd..6889059 100644
--- a/docs/alfred/fm/huggingface.md
+++ b/docs/alfred/fm/huggingface.md
@@ -47,7 +47,7 @@ class HuggingFaceModel(LocalAccessFoundationModel):
 
 ### HuggingFaceModel()._encode_batch
 
-[Show source in huggingface.py:438](../../../alfred/fm/huggingface.py#L438)
+[Show source in huggingface.py:441](../../../alfred/fm/huggingface.py#L441)
 
 Encode given batch of instances.
 
@@ -71,7 +71,7 @@ def _encode_batch(self, batch_instance, **kwargs) -> List[torch.Tensor]: ...
 
 ### HuggingFaceModel()._generate_batch
 
-[Show source in huggingface.py:348](../../../alfred/fm/huggingface.py#L348)
+[Show source in huggingface.py:350](../../../alfred/fm/huggingface.py#L350)
 
 Generate completions for a batch of prompts using the model.
 
@@ -114,7 +114,7 @@ def _generate_batch(
 
 ### HuggingFaceModel()._get_hidden_states
 
-[Show source in huggingface.py:173](../../../alfred/fm/huggingface.py#L173)
+[Show source in huggingface.py:175](../../../alfred/fm/huggingface.py#L175)
 
 Get the hidden states of the inputs. For encoder-decoder models (e.g.) T5, this returns the encoder hidden states.
 
@@ -140,7 +140,7 @@ def _get_hidden_states(self, inputs, reduction="mean") -> torch.Tensor: ...
 
 ### HuggingFaceModel()._score_batch
 
-[Show source in huggingface.py:212](../../../alfred/fm/huggingface.py#L212)
+[Show source in huggingface.py:214](../../../alfred/fm/huggingface.py#L214)
 
 Score a batch of prompts and candidates using the model.
 
@@ -180,7 +180,7 @@ def _score_batch(
 
 ### HuggingFaceModel().chat
 
-[Show source in huggingface.py:464](../../../alfred/fm/huggingface.py#L464)
+[Show source in huggingface.py:467](../../../alfred/fm/huggingface.py#L467)
 
 Launch an interactive chat session
 
diff --git a/docs/alfred/fm/utils.md b/docs/alfred/fm/utils.md
index a852d9d..294e010 100644
--- a/docs/alfred/fm/utils.md
+++ b/docs/alfred/fm/utils.md
@@ -380,7 +380,7 @@ Batch queries into fixed size batches
 #### Arguments
 
 - `queries` - A list of queries to be batched
-:type queries: List[Query]
+:type queries: Union[Query, str]
 - `batch_sz` - The batch size
 :type batch_sz: int
 
@@ -392,7 +392,9 @@ Type: *List[List[Query]]*
 #### Signature
 
 ```python
-def static_batch(queries: Query, batch_sz: int = 1024) -> List[List[Query]]: ...
+def static_batch(
+    queries: Union[Query, str], batch_size: int = 512
+) -> List[List[Query]]: ...
 ```
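
For reviewers, a minimal sketch of the decoding switch the `_generate_batch` hunk implements. It is illustrative only: "gpt2" stands in for whatever checkpoint `HuggingFaceModel` wraps, and `generate_with_temperature` is a hypothetical helper, not part of alfred.

```python
# Illustrative sketch only: "gpt2" and generate_with_temperature() are
# stand-ins, not alfred code. Mirrors the temperature -> do_sample mapping
# added in _generate_batch above.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

def generate_with_temperature(temperature: float, max_new_tokens: int = 16):
    # temperature == 0 is treated as greedy decoding: do_sample stays False
    # and temperature is suppressed (None), which avoids the transformers
    # warning about `temperature` being set while `do_sample` is False.
    do_sample = temperature != 0
    return model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        return_dict_in_generate=True,
    )

greedy = generate_with_temperature(0)     # deterministic, no warning
sampled = generate_with_temperature(0.7)  # stochastic sampling
```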
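
Likewise, a usage sketch for the widened `static_batch` signature, assuming the patched `alfred.fm.utils` and the existing `CompletionQuery` class; the prompt strings and counts are arbitrary.

```python
# Usage sketch for the patched static_batch: raw strings and Query objects
# may now be mixed in one call. Assumes the patched alfred.fm.utils.
from alfred.fm.query import CompletionQuery
from alfred.fm.utils import static_batch

queries = ["Translate to French: hello"] + [CompletionQuery("2 + 2 =")] * 1023
batches = static_batch(queries, batch_size=512)

# 1024 prompts at the new default batch_size of 512 yield two full batches.
assert len(batches) == 2
assert all(len(b) == 512 for b in batches)
```

Note that the patched loop only prints a warning for unrecognized query types; it does not raise.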