separating static batching as a utils function
dotpyu committed Jul 8, 2024
1 parent 7c19a4b commit 1486304
Showing 4 changed files with 66 additions and 14 deletions.
10 changes: 3 additions & 7 deletions alfred/fm/model.py
@@ -11,7 +11,7 @@

from .query import Query, RankedQuery, CompletionQuery
from .response import Response, CompletionResponse, RankedResponse
-from .utils import DynamicBatcher, clear_cuda_cache, batch_multimodal
+from .utils import DynamicBatcher, clear_cuda_cache, batch_multimodal, static_batch

logger = logging.getLogger(__name__)

@@ -131,9 +131,7 @@ def forward(
)
except AttributeError:
if batch_policy == "static":
-batched_queries = np.array_split(
-    queries, max(1, len(queries) // batch_size)
-)
+batched_queries = static_batch(queries, batch_size=batch_size)
pretokenized = False
elif batch_policy == "dynamic":
if pretokenize:
@@ -200,9 +198,7 @@ def forward(
clear_cuda_cache()
if batch_policy == "static":
batch_size = int(batch_size * 0.8)
-batched_queries = np.array_split(
-    queries, len(queries) // batch_size
-)
+batched_queries = static_batch(queries, batch_size=batch_size)
logging.info(f"New batch size: {batch_size}")
elif batch_policy == "dynamic":
DB = DynamicBatcher(
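The second hunk above retries static batching with a smaller batch size after a CUDA out-of-memory error: the cache is cleared, `batch_size` is cut to 80%, and the queries are re-batched with `static_batch`. A minimal standalone sketch of that pattern, not code from the commit; `run_with_backoff` and `run_inference` are hypothetical stand-ins for the model's per-batch inference call:

```python
import torch

from alfred.fm.utils import clear_cuda_cache, static_batch


def run_with_backoff(queries, run_inference, batch_size=1024):
    """Run batched inference, shrinking the static batch size on CUDA OOM."""
    while batch_size > 0:
        batched_queries = static_batch(queries, batch_size=batch_size)
        try:
            return [run_inference(batch) for batch in batched_queries]
        except torch.cuda.OutOfMemoryError:
            clear_cuda_cache()                  # release cached GPU memory before retrying
            batch_size = int(batch_size * 0.8)  # same 0.8 backoff factor as in forward()
    raise RuntimeError("Out of GPU memory even at batch size 1")
```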
28 changes: 28 additions & 0 deletions alfred/fm/utils.py
@@ -524,3 +524,31 @@ def _process_batch(batch):
clear_cuda_cache()

return batches


def static_batch(queries: List[Query], batch_size: int = 1024) -> List[List[Query]]:
    """
    Static Batching Utility
    Batch queries into fixed-size batches

    :param queries: A list of queries to be batched
    :type queries: List[Query]
    :param batch_size: The batch size
    :type batch_size: int
    :return: A list of batches
    :rtype: List[List[Query]]
    """
    batches = []
    batch = []
    for query in queries:
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
        if isinstance(query, CompletionQuery):
            _q = query.load()[0]
        elif isinstance(query, RankedQuery):
            _q = query.prompt
        else:
            # Fall back to the raw query for plain strings or other types
            _q = query
        batch.append(_q)
    if len(batch) > 0:
        batches.append(batch)
    return batches
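A quick usage illustration, not part of the commit; the prompt strings are made up, and the assumption that `CompletionQuery` wraps a single prompt string is inferred from the `query.load()[0]` call above:

```python
from alfred.fm.query import CompletionQuery
from alfred.fm.utils import static_batch

queries = [CompletionQuery(f"Summarize document {i}") for i in range(5)]

# Fixed-size batches of 2; the final batch holds the remainder.
# Each batch contains the extracted prompt payloads, not the Query objects themselves.
batches = static_batch(queries, batch_size=2)
print([len(b) for b in batches])  # [2, 2, 1]
```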
14 changes: 7 additions & 7 deletions docs/alfred/fm/model.md
@@ -20,7 +20,7 @@

## APIAccessFoundationModel

-[Show source in model.py:382](../../../alfred/fm/model.py#L382)
+[Show source in model.py:378](../../../alfred/fm/model.py#L378)

#### Signature

@@ -49,7 +49,7 @@ class FoundationModel(abc.ABC): ...

### FoundationModel().__call__

-[Show source in model.py:360](../../../alfred/fm/model.py#L360)
+[Show source in model.py:356](../../../alfred/fm/model.py#L356)

This function returns the output of the run function when the
model is called as a function. It can be used as model(queries),
@@ -157,7 +157,7 @@ def _score_batch(

### FoundationModel().encode

-[Show source in model.py:277](../../../alfred/fm/model.py#L277)
+[Show source in model.py:273](../../../alfred/fm/model.py#L273)

This function is a wrapper around the forward function

@@ -239,7 +239,7 @@ def forward(

### FoundationModel().generate

-[Show source in model.py:226](../../../alfred/fm/model.py#L226)
+[Show source in model.py:222](../../../alfred/fm/model.py#L222)

This function is a wrapper around the forward function for running
CompletionQuery objects through the foundation model. It returns a list
@@ -275,7 +275,7 @@ def generate(

### FoundationModel().run

-[Show source in model.py:308](../../../alfred/fm/model.py#L308)
+[Show source in model.py:304](../../../alfred/fm/model.py#L304)

This function is the main entry point for users to run queries through the foundation model.
It accepts raw query content and automatically converts it into query objects.
@@ -308,7 +308,7 @@ def run(

### FoundationModel().score

-[Show source in model.py:251](../../../alfred/fm/model.py#L251)
+[Show source in model.py:247](../../../alfred/fm/model.py#L247)

This function is a wrapper around the forward function
for running RankedQuery objects through the foundation model.
@@ -346,7 +346,7 @@ def score(

## LocalAccessFoundationModel

-[Show source in model.py:397](../../../alfred/fm/model.py#L397)
+[Show source in model.py:393](../../../alfred/fm/model.py#L393)

#### Signature

28 changes: 28 additions & 0 deletions docs/alfred/fm/utils.md
@@ -21,6 +21,7 @@
- [normalize_logits](#normalize_logits)
- [reorder_array](#reorder_array)
- [retry](#retry)
- [static_batch](#static_batch)
- [tokenize](#tokenize)
- [type_print](#type_print)

@@ -369,6 +370,33 @@ def retry(num_retries=3, wait_time=0.1, exceptions=(Exception)): ...



## static_batch

[Show source in utils.py:529](../../../alfred/fm/utils.py#L529)

Static Batching Utility
Batch queries into fixed-size batches

#### Arguments

- `queries` - A list of queries to be batched
:type queries: List[Query]
- `batch_size` - The batch size
:type batch_size: int

#### Returns

A list of batches
Type: *List[List[Query]]*

#### Signature

```python
def static_batch(queries: List[Query], batch_size: int = 1024) -> List[List[Query]]: ...
```



## tokenize

[Show source in utils.py:89](../../../alfred/fm/utils.py#L89)
