Remove transformers and make torch optional #8

Merged
2 commits merged on Sep 6, 2024
66 changes: 9 additions & 57 deletions ldp/graph/async_torch.py
@@ -8,11 +8,15 @@
 from typing import Any
 from uuid import UUID, uuid4

-import torch
-from torch import nn
-from torch.nn.functional import pad
-from torch.utils.data import default_collate
-from transformers.generation.utils import GenerateDecoderOnlyOutput
+try:
+    import torch
+    from torch import nn
+    from torch.utils.data import default_collate
+except ImportError:
+    raise ImportError(
+        "ldp.graph.async_torch requires PyTorch as a dependency. "
+        "Please run `pip install ldp[nn]`."
+    ) from None

 _TORCH_LOCK = asyncio.Lock()

@@ -168,55 +172,3 @@ async def _batched_call(self):
         request_ids = [x[1] for x in batch]
         results = self.decollate_fn(batched_results)
         self._result_buffer.update(zip(request_ids, results, strict=True))
-
-    @staticmethod
-    def collate_fn_transformers_model(
-        samples: list[dict[str, torch.Tensor]], agg_keys: set[str] | None = None
-    ) -> dict[str, torch.Tensor]:
-        """Collates and pads a batch of samples for input into a huggingface transformer model."""
-        if agg_keys is None:
-            agg_keys = {"input_ids", "attention_mask"}
-        seq_lens = [inp["input_ids"].shape[1] for inp in samples]
-        max_seq_len = max(seq_lens)
-        n_pads = [max_seq_len - seq_len for seq_len in seq_lens]
-
-        batch = {
-            key: torch.cat(
-                [
-                    pad(inp[key], (0, n_pad), value=0)
-                    for inp, n_pad in zip(samples, n_pads, strict=True)
-                ],
-                dim=0,
-            )
-            for key in agg_keys
-        }
-
-        # Treating other keys as constant kwargs params for the model
-        other_keys = set(samples[0].keys()) - agg_keys
-        for key in other_keys:
-            for sample in samples:
-                if key not in sample:
-                    raise ValueError(f"Missing key {key} in sample.")
-                if key in batch and batch[key] != sample[key]:
-                    raise ValueError(
-                        f"Constant kwarg key {key} has different values within batch."
-                    )
-                batch[key] = sample[key]
-
-        return batch
-
-    @staticmethod
-    def decollate_fn_transformers_decoder(
-        batched_output: GenerateDecoderOnlyOutput,
-    ) -> list[GenerateDecoderOnlyOutput]:
-        """Decollates a batched output from a huggingface transformer decoder."""
-        batch_size = batched_output.sequences.size(0)
-
-        return [
-            GenerateDecoderOnlyOutput({
-                "sequences": batched_output.sequences[i][None],
-                "scores": [v[i][None] for v in batched_output.scores],
-                # Ignore other keys for now
-            })
-            for i in range(batch_size)
-        ]
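Note: since the transformers-specific helpers above are deleted outright rather than relocated, downstream code that passed collate_fn_transformers_model / decollate_fn_transformers_decoder as collate_fn / decollate_fn will need local equivalents. A minimal sketch of such a replacement, assuming the caller installs transformers itself (the names collate_padded and decollate_decoder are illustrative, not part of ldp):

# Hypothetical downstream replacements for the removed helpers (illustrative only).
import torch
from torch.nn.functional import pad
from transformers.generation.utils import GenerateDecoderOnlyOutput


def collate_padded(samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Right-pad input_ids/attention_mask to a common length and stack along the batch dim."""
    max_len = max(s["input_ids"].shape[1] for s in samples)
    return {
        key: torch.cat(
            [pad(s[key], (0, max_len - s[key].shape[1]), value=0) for s in samples],
            dim=0,
        )
        for key in ("input_ids", "attention_mask")
    }


def decollate_decoder(batched: GenerateDecoderOnlyOutput) -> list[GenerateDecoderOnlyOutput]:
    """Split a batched generate() output back into one output per request."""
    return [
        GenerateDecoderOnlyOutput({
            "sequences": batched.sequences[i][None],
            "scores": [v[i][None] for v in batched.scores],
        })
        for i in range(batched.sequences.size(0))
    ]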
10 changes: 8 additions & 2 deletions ldp/graph/torch_ops.py
@@ -2,8 +2,14 @@
 from collections.abc import Mapping, Sequence
 from typing import Any, ClassVar

-import torch
-from torch import nn
+try:
+    import torch
+    from torch import nn
+except ImportError:
+    raise ImportError(
+        "ldp.graph.torch_ops requires PyTorch as a dependency. "
+        "Please run `pip install ldp[nn]`."
+    ) from None

 from ldp.graph.async_torch import async_protect_torch_call
 from ldp.graph.op_utils import CallID, get_call_id, get_training_mode
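For callers that want to keep working in torch-free installs, one option is to check for torch before touching the guarded modules. A small sketch, assuming nothing beyond the modules touched in this PR (the HAS_TORCH name is illustrative):

# Illustrative caller-side guard: only import the torch-backed graph ops when
# PyTorch is actually installed, mirroring the ImportError guard added above.
import importlib.util

HAS_TORCH = importlib.util.find_spec("torch") is not None

if HAS_TORCH:
    from ldp.graph import torch_ops  # available after `pip install ldp[nn]`
else:
    torch_ops = None  # torch-free install: skip the nn-only code paths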
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -26,9 +26,7 @@ dependencies = [
     "pydantic~=2.0",
     "tenacity",
     "tiktoken",
-    "torch",
     "tqdm",
-    "transformers",
     "typing-extensions; python_version <= '3.11'", # for typing.override
     "usearch>=2.13", # For py.typed
 ]
@@ -42,6 +40,9 @@ requires-python = ">=3.11"
 monitor = [
     "wandb",
 ]
+nn = [
+    "torch>=2.2",
+]
 server = [
     "fastapi>=0.109", # For Python 3.12 support
 ]
@@ -137,7 +138,6 @@ module = [
     "litellm", # SEE: https://github.com/BerriAI/litellm/issues/825
     "networkx", # SEE: https://github.com/networkx/networkx/issues/3988
     "pydot",
-    "transformers.*", # SEE: https://github.com/huggingface/transformers/pull/18485
     "tree", # SEE: https://github.com/google-deepmind/tree/issues/84
 ]

@@ -398,7 +398,7 @@ dev-dependencies = [
     "codeflash",
     "fhaviary[xml]",
     "ipython>=8", # Pin to keep recent
-    "ldp[monitor,server,typing,visualization]",
+    "ldp[monitor,nn,server,typing,visualization]",
     "litellm>=1.40.9,<=1.40.12", # Pin lower for get_supported_openai_params not requiring custom LLM, upper for https://github.com/BerriAI/litellm/issues/4032
     "mypy>=1.8", # Pin for mutable-override
     "pre-commit~=3.4", # Pin to keep recent
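With these extras in place, a plain `pip install ldp` no longer pulls in torch or transformers, and the torch-backed modules come back via the new `nn` extra (`pip install ldp[nn]`, which pins torch>=2.2), matching the guidance in the new ImportError messages. A quick way to see the new failure mode in a torch-free environment (the expected-output comment is an assumption based on the message added in this diff):

# In an environment installed without the `nn` extra, importing a guarded module
# should now raise the descriptive ImportError added above, not a bare ModuleNotFoundError.
try:
    import ldp.graph.async_torch  # noqa: F401
except ImportError as exc:
    print(exc)  # expected: "ldp.graph.async_torch requires PyTorch ... `pip install ldp[nn]`."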
76 changes: 7 additions & 69 deletions uv.lock

Some generated files are not rendered by default.