Merge pull request #80 from BatsResearch/modularize_static_batching
trigger do_sample automatically based on temperature for huggingface …
dotpyu authored Jul 21, 2024
2 parents 8abca96 + 7fca0f1 commit 8847315
Showing 4 changed files with 22 additions and 11 deletions.
7 changes: 5 additions & 2 deletions alfred/fm/huggingface.py
@@ -130,7 +130,9 @@ def __init__(

         if torch.cuda.is_available():
             n_gpus = torch.cuda.device_count()
-            free_in_GB = sum([int(mem / 1024**3) for mem in torch.cuda.mem_get_info()])
+            free_in_GB = sum(
+                [int(mem / 1024**3) for mem in torch.cuda.mem_get_info()]
+            )

             logger.log(
                 logging.INFO, f"Found {n_gpus} GPUs with {free_in_GB}GB free GPU memory"
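As an aside, a self-contained sketch of one way to total free GPU memory across devices with torch; this is an illustration of the utility being reformatted above, not the repository's exact logic. `torch.cuda.mem_get_info(device)` returns a `(free_bytes, total_bytes)` tuple per device:

```python
import torch

# Illustrative sketch: sum the free bytes reported for each visible GPU.
if torch.cuda.is_available():
    free_in_gb = sum(
        torch.cuda.mem_get_info(i)[0] // 1024**3
        for i in range(torch.cuda.device_count())
    )
    print(f"{free_in_gb}GB free across {torch.cuda.device_count()} GPUs")
```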
@@ -398,9 +400,10 @@ def _generate_batch(
             outputs = self.model.generate(
                 inputs.input_ids.to(self.model.device),
                 max_new_tokens=max_new_tokens,
-                temperature=temprature,
+                temperature=temprature if temprature != 0 else None,
                 repetition_penalty=repetition_penalty,
                 return_dict_in_generate=True,
+                do_sample=temprature != 0,
             )
         else:
             outputs = [
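The substantive change above is the `do_sample` toggle: a temperature of 0 now selects greedy decoding, while any nonzero temperature enables sampling. A minimal sketch of the same pattern against the transformers `generate` API (the model name and helper function are illustrative, not taken from the repository):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate(prompt: str, temperature: float = 0.0, max_new_tokens: int = 32) -> str:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        # temperature == 0 means "deterministic": disable sampling and omit
        # temperature so transformers does not warn about an unused argument
        temperature=temperature if temperature != 0 else None,
        do_sample=temperature != 0,
        return_dict_in_generate=True,
    )
    return tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
```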
10 changes: 8 additions & 2 deletions alfred/fm/utils.py
@@ -526,13 +526,15 @@ def _process_batch(batch):
     return batches


-def static_batch(queries: Query, batch_size: int = 1024) -> List[List[Query]]:
+def static_batch(
+    queries: Union[Query, str], batch_size: int = 512
+) -> List[List[Query]]:
     """
     Static Batching Utility
     Batch queries into fixed size batches
     :param queries: A list of queries to be batched
-    :type queries: List[Query]
+    :type queries: Union[Query, str]
     :param batch_sz: The batch size
     :type batch_sz: int
     :return: A list of batches
@@ -548,6 +550,10 @@ def static_batch(queries: Query, batch_size: int = 1024) -> List[List[Query]]:
             _q = query.load()[0]
         elif isinstance(query, RankedQuery):
             _q = query.prompt
+        elif isinstance(query, str):
+            _q = query
+        else:
+            print(f"Unknown query type {type(query)}")
         batch.append(_q)
     if len(batch) > 0:
         batches.append(batch)
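The new `elif`/`else` arms above let `static_batch` accept raw strings alongside `Query` objects. A minimal, self-contained sketch of the overall batching loop under that assumption (the flush-at-`batch_size` step is implied by the surrounding code but not shown in this hunk):

```python
from typing import List, Union

def static_batch_sketch(
    queries: List[Union[str, object]], batch_size: int = 512
) -> List[List[str]]:
    batches: List[List[str]] = []
    batch: List[str] = []
    for query in queries:
        # Strings pass through unchanged; other query types would be
        # unwrapped to their prompt text here, as in the diff above.
        _q = query if isinstance(query, str) else str(query)
        batch.append(_q)
        if len(batch) == batch_size:  # flush a full batch
            batches.append(batch)
            batch = []
    if len(batch) > 0:  # flush the trailing partial batch
        batches.append(batch)
    return batches
```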
10 changes: 5 additions & 5 deletions docs/alfred/fm/huggingface.md
@@ -47,7 +47,7 @@ class HuggingFaceModel(LocalAccessFoundationModel):

### HuggingFaceModel()._encode_batch

-[Show source in huggingface.py:438](../../../alfred/fm/huggingface.py#L438)
+[Show source in huggingface.py:441](../../../alfred/fm/huggingface.py#L441)

Encode given batch of instances.

@@ -71,7 +71,7 @@ def _encode_batch(self, batch_instance, **kwargs) -> List[torch.Tensor]: ...

### HuggingFaceModel()._generate_batch

-[Show source in huggingface.py:348](../../../alfred/fm/huggingface.py#L348)
+[Show source in huggingface.py:350](../../../alfred/fm/huggingface.py#L350)

Generate completions for a batch of prompts using the model.

@@ -114,7 +114,7 @@ def _generate_batch(

### HuggingFaceModel()._get_hidden_states

-[Show source in huggingface.py:173](../../../alfred/fm/huggingface.py#L173)
+[Show source in huggingface.py:175](../../../alfred/fm/huggingface.py#L175)

Get the hidden states of the inputs.
For encoder-decoder models (e.g., T5), this returns the encoder hidden states.
@@ -140,7 +140,7 @@ def _get_hidden_states(self, inputs, reduction="mean") -> torch.Tensor: ...

### HuggingFaceModel()._score_batch

-[Show source in huggingface.py:212](../../../alfred/fm/huggingface.py#L212)
+[Show source in huggingface.py:214](../../../alfred/fm/huggingface.py#L214)

Score a batch of prompts and candidates using the model.

@@ -180,7 +180,7 @@ def _score_batch(

### HuggingFaceModel().chat

-[Show source in huggingface.py:464](../../../alfred/fm/huggingface.py#L464)
+[Show source in huggingface.py:467](../../../alfred/fm/huggingface.py#L467)

Launch an interactive chat session

6 changes: 4 additions & 2 deletions docs/alfred/fm/utils.md
@@ -380,7 +380,7 @@ Batch queries into fixed size batches
#### Arguments

- `queries` - A list of queries to be batched
-  :type queries: List[Query]
+  :type queries: Union[Query, str]
- `batch_sz` - The batch size
:type batch_sz: int

@@ -392,7 +392,9 @@ Type: *List[List[Query]]*
#### Signature

```python
-def static_batch(queries: Query, batch_sz: int = 1024) -> List[List[Query]]: ...
+def static_batch(
+    queries: Union[Query, str], batch_size: int = 512
+) -> List[List[Query]]: ...
```
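For reference, a hypothetical call with plain strings under the updated signature (assuming batches are flushed once they reach `batch_size`):

```python
batches = static_batch(["q1", "q2", "q3"], batch_size=2)
# -> [["q1", "q2"], ["q3"]]
```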


