From ecd19189d52922959b731b1e66b3664b8ac0dcbf Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 22 Sep 2025 17:34:02 -0400 Subject: [PATCH 01/15] clean history --- Dockerfile | 5 +- fast_llm/data/data/gpt/data.py | 18 + fast_llm/data/dataset/gpt/config.py | 13 +- fast_llm/data/dataset/gpt/fim.py | 6 +- fast_llm/data/dataset/gpt/indexed.py | 26 +- fast_llm/data/dataset/gpt/memmap.py | 168 +++++++-- fast_llm/data/dataset/gpt/sampled.py | 167 ++++++++- fast_llm/data/preparator/gpt_memmap/config.py | 17 + .../data/preparator/gpt_memmap/prepare.py | 136 ++++--- fast_llm/data/tokenizer.py | 83 +++-- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/engine/schedule/config.py | 6 + fast_llm/engine/training/config.py | 2 +- fast_llm/functional/config.py | 14 +- fast_llm/functional/cross_entropy.py | 210 +++++++++-- fast_llm/functional/triton/mlp.py | 6 +- fast_llm/layers/attention/attention.py | 4 +- fast_llm/layers/attention/config.py | 5 + fast_llm/layers/attention/rotary/config.py | 8 + fast_llm/layers/attention/rotary/rotary.py | 70 ++++ fast_llm/layers/language_model/config.py | 6 + fast_llm/layers/language_model/head.py | 116 ++++-- fast_llm/layers/multi_modal/embedding.py | 183 +++++++++ fast_llm/layers/ssm/preprocessing.py | 68 ++++ fast_llm/layers/vision_encoder/adapter.py | 55 +++ fast_llm/layers/vision_encoder/config.py | 181 +++++++++ fast_llm/layers/vision_encoder/patch_conv.py | 62 ++++ .../layers/vision_encoder/preprocessing.py | 281 ++++++++++++++ fast_llm/models/custom/model.py | 70 ++++ fast_llm/models/gpt/config.py | 29 ++ fast_llm/models/gpt/conversion/auto.py | 6 + fast_llm/models/gpt/conversion/config.py | 24 ++ fast_llm/models/gpt/conversion/llava.py | 155 ++++++++ .../models/gpt/conversion/llava_hybrid.py | 40 ++ fast_llm/models/gpt/conversion/pixtral.py | 266 +++++++++++++ fast_llm/models/gpt/llava.py | 0 fast_llm/models/gpt/model.py | 219 ++++++++++- fast_llm/models/gpt/trainer.py | 9 + .../modeling_apriel_hybrid_ssm.py | 66 +++- .../configuration_llava_hybrid.py | 117 ++++++ .../llava_hybrid/modeling_llava_hybrid.py | 132 +++++++ setup.cfg | 15 +- tests/data/common.py | 11 +- tests/data/test_sampling.py | 14 +- tests/layers/test_lm_head.py | 2 - tests/models/test_checkpoint.py | 15 +- tests/test_ssms.py | 349 ++++++++++++++++++ tests/utils/model_configs.py | 72 ++++ 49 files changed, 3307 insertions(+), 224 deletions(-) create mode 100644 fast_llm/layers/multi_modal/embedding.py create mode 100644 fast_llm/layers/ssm/preprocessing.py create mode 100644 fast_llm/layers/vision_encoder/adapter.py create mode 100644 fast_llm/layers/vision_encoder/config.py create mode 100644 fast_llm/layers/vision_encoder/patch_conv.py create mode 100644 fast_llm/layers/vision_encoder/preprocessing.py create mode 100644 fast_llm/models/custom/model.py create mode 100644 fast_llm/models/gpt/conversion/llava.py create mode 100644 fast_llm/models/gpt/conversion/llava_hybrid.py create mode 100644 fast_llm/models/gpt/conversion/pixtral.py create mode 100644 fast_llm/models/gpt/llava.py create mode 100644 fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py create mode 100644 fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py create mode 100644 tests/test_ssms.py diff --git a/Dockerfile b/Dockerfile index 526026fa4..abb0759b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. 
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) +# Using varlen_mamba for variable length sequence support RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ @@ -38,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..9df9b9b86 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,6 +32,8 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[torch.Tensor] | None = None + image_positions: list[torch.Tensor] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None @@ -49,12 +51,28 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = False + batch_images = [] + for sample in batch: + if sample.images is not None: + batch_images.append([torch.from_numpy(image) for image in sample.images]) + has_images = True + else: + batch_images.append([]) + batch_image_positions = [] + for sample in batch: + if sample.image_positions is not None: + batch_image_positions.append(torch.from_numpy(sample.image_positions)) + else: + batch_image_positions.append([]) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, + images=batch_images if has_images else None, + image_positions=batch_image_positions if has_images else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 656cd7d24..8835480a9 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -72,6 +72,10 @@ class GPTSamplingParameters(SamplingParameters): use_preference_loss_spans: bool = False cross_document_attention: bool = True truncate_documents: bool = True + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None # How many extra 
tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 @@ -138,11 +142,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..b05b79b24 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64) + middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 896229772..b069e36e2 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -30,6 +30,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": return GPTSampledIndexedDataset(self, sampling) + @property + @abc.abstractmethod + def has_images(self) -> bool: + """ + Whether the dataset contains images. + This is used to determine whether to use image-related fields in the sampled data. + """ + class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ @@ -40,11 +48,16 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + doc_sizes, im_sizes = self._dataset.get_document_sizes() + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([]) def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) + @property + def has_images(self) -> bool: + return self._dataset.has_images + class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset @@ -53,8 +66,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + sizes = [dataset.get_document_sizes() for dataset in self._datasets] + return ( + np.concatenate([size[0] for size in sizes]), + np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]), + ) def get_document_size(self, index: int) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + @property + def has_images(self) -> bool: + return any(dataset.has_images for dataset in self._datasets) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f4..4f62561a8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,10 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,32 +28,46 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_images = 0 self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: self._has_preference_spans = struct.unpack("= 4: + self._has_images = struct.unpack("= 2: @@ -77,9 +94,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, + offset=offset, ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] for idx in range(self._num_documents): self._spans.append( @@ -87,30 +103,29 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + offset=offset + + self._num_spans.nbytes + + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - + offset += self._num_spans.nbytes + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize # read preference spans self._chosen_spans = None self._rejected_spans = None if self._has_preference_spans and self._version >= 3: self._chosen_spans = [] self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes for idx in range(self._num_documents): self._chosen_spans.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=2, - 
offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, + offset=offset + idx * 2 * np.dtype(np.int32).itemsize, ) ) - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) + rejected_span_offset = offset + np.array(self._chosen_spans).nbytes for idx in range(self._num_documents): self._rejected_spans.append( np.frombuffer( @@ -120,16 +135,53 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, ) ) + offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes + + self._num_pixels = 0 + self._image_sizes = [] + self._image_positions = None + if self._has_images and self._version >= 4: + self._n_images = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._image_sizes = [] + self._image_positions = [] + images_seen = 0 + num_total_images = self._n_images.sum() + for n_images in self._n_images: + self._image_sizes.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images * 2, + offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + self._num_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() + self._image_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images, + offset=offset + + self._n_images.nbytes + + 2 * num_total_images * np.dtype(np.int32).itemsize + + +images_seen * np.dtype(np.int32).itemsize, + ) + ) + images_seen += n_images self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) + if num_pixels is not None: + assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -156,6 +208,24 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = None + image_positions = None + if self._has_images: + image_positions = self._image_positions[idx] + + # Truncations with images are not yet supported, so we get all images from the document + pixels = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8), + count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + images = [] + start = 0 + for image_size in self._image_sizes[idx]: + n_pixels = image_size.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) + start += n_pixels sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] @@ -202,6 +272,8 @@ def get( return GPTSample( token_ids=token_ids, + images=images, + image_positions=image_positions, 
loss_masking_spans=sample_spans, chosen_span=chosen_span, rejected_span=rejected_span, @@ -218,23 +290,31 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def get_document_sizes(self) -> np.ndarray: + @property + def has_images(self) -> bool: + return self._has_images + + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return self._document_sizes, self._image_sizes def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + return self._document_sizes[index].item(), self._image_sizes[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + image_sizes = [] + im_positions = [] + total_images = 0 pointers = [] offset = 0 # number of spans for each document @@ -259,10 +339,28 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + total_im_size = 0 + if document.images: + n_images.append(len(document.images)) + total_images += len(document.images) + for image in document.images: + # assume 3 channels (RGB) for all images + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
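+                        # Note on the .bin layout: the raw uint8 CHW pixels written below are appended
+                        # to the same record as the document's token_ids, and the pointer offset further
+                        # down advances by both, which is what lets get() read a document's pixels at
+                        # pointers[idx] + document_sizes[idx] * itemsize.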
+ image_sizes.append(np.array(pixels.shape[1:])) + bin_stream.write(pixels.tobytes(order="C")) + total_im_size += pixels.size + im_positions.extend(document.image_positions) + else: + n_images.append(0) # Update metadata doc_length = len(document.token_ids) - lengths.append(doc_length) + doc_lengths.append(doc_length) pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) @@ -271,11 +369,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans.append(document.chosen_span) if document.rejected_span is not None: rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: @@ -285,25 +383,37 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + if total_images: + n_images = np.array(n_images, dtype=np.int32) + image_sizes = np.stack(image_sizes, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) + else: + n_images = np.array([]) + image_sizes = np.array([]) + im_positions = np.array([]) + # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans + # Version 2 onwards optionally add loss-masking spans # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) # Flag to indicate whether preference loss-masking spans are present idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes = torch.from_numpy(document_sizes).to(self._device) + if image_sizes: + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_image_tokens( + *get_resize_dims( + *size, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for size in sizes + ) + ) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + else: + image_token_sizes = torch.zeros_like(document_sizes) + documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() # Calculate basic stats. if not self._truncate_documents: @@ -143,14 +175,14 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." 
" Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 + long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 ignored_documents = long_docs_filter.sum().item() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -193,7 +225,10 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, + "image_break_token": self._parameters.image_break_token, + "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -294,7 +329,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -317,6 +352,9 @@ def _sample(self) -> None: document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 ) + ] + + image_token_sizes[ + document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) ], offset=self._unshuffled_tokens, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -442,6 +480,10 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] + images = [] + image_positions = [] + image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -449,7 +491,28 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + document_size = text_size + image_tokens if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -468,21 +531,97 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. 
token_count += padding_size + continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - token_ids.append(sample.token_ids) + start_pos = 0 + has_images = sample.image_positions is not None + if has_images: + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if self._parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = self._parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_tokens_added += image_sizes[idx] + start_pos = im_position + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + else: + token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) + if sample.images: + images.append(sample.images) + else: + images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: + prev_image_tokens = 0 + image_idx = 0 + image_position = ( + sample.image_positions[image_idx] + if has_images and image_idx < len(sample.image_positions) + else float("inf") + ) + while image_position < loss_masking_span[0]: + prev_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if has_images and image_idx < len(sample.image_positions) + else float("inf") + ) + span_image_tokens = 0 + while image_position <= 
loss_masking_span[1]: + span_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if has_images and image_idx < len(sample.image_positions) + else float("inf") + ) + loss_masking_span[0] += prev_image_tokens + loss_masking_span[1] += prev_image_tokens + span_image_tokens + prev_image_tokens += span_image_tokens span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -506,9 +645,17 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..da353793d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -175,6 +187,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." 
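Note (illustrative sketch, not part of the patch): the sampling logic above relies on get_resize_dims and get_num_image_tokens, whose definitions are not visible in these hunks (they presumably live in the new vision-encoder preprocessing module). Under that assumption, the counting they are expected to perform, consistent with the placeholder-filling in __getitem__ (one break token per patch row, the last one optionally replaced by an end token), looks roughly like:

    # Sketch only; the real helpers are defined elsewhere in this patch and may differ.
    def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]:
        # Assumption: shrink to fit within the maximum size, then round up to whole patches.
        ratio = max(height / max_height, width / max_width, 1.0)
        height, width = round(height / ratio), round(width / ratio)
        return (
            patch_size * ((height + patch_size - 1) // patch_size),
            patch_size * ((width + patch_size - 1) // patch_size),
        )

    def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int:
        rows, cols = height // patch_size, width // patch_size
        tokens = rows * cols  # one placeholder token per patch
        if image_break:
            tokens += rows  # a break token after every row (the last one may become the end token)
        elif image_end:
            tokens += 1  # a single end token when only image_end_token is configured
        return tokens

For example, a 512x384 image with patch_size=16 would contribute 32*24 = 768 patch tokens, plus 32 break tokens when image_break_token is set.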
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 33c40bf8f..94eede19f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -8,6 +10,7 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm @@ -24,7 +27,11 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -39,36 +46,44 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _loss_masking_spans_column: str | None def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) + for input_ids, token_spans, image_token_positions in [ + self._tokenizer.tokenize( + text, + loss_mask_spans, + im_char_positions, + ) + for text, loss_mask_spans, im_char_positions in zip( + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), + ) ] ] ), ) num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 + return { "input_ids": input_ids, + "image_positions": image_token_positions, "token_spans": token_spans, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -141,27 +156,22 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], 
dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + has_preference_spans = ( + self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None + ) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._loss_masking_spans_column + else None + ), + item["chosen_token_spans"] if has_preference_spans else None, + item["rejected_token_spans"] if has_preference_spans else None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -171,6 +181,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -290,6 +301,11 @@ def run(self) -> None: if isinstance(self._config.dataset.source_schema, TextColumnConfig): self._text_column = self._config.dataset.source_schema.input_column self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column + if isinstance(self._config.dataset.source_schema, TextImageColumnConfig): + self._images_column = self._config.dataset.source_schema.images_column + self._image_positions_column = self._config.dataset.source_schema.image_positions_column + # decoding bytes to images is slow and should be done only when needed + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) else: raise ValueError( f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." 
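Note (illustrative, not part of the patch): with the new "text_image_column" source schema, each input row is expected to provide raw image bytes plus character offsets into the text; casting the images column with decode=False, as done above, keeps the bytes undecoded until they are actually needed. A hypothetical row using the default column names from TextImageColumnConfig:

    # Hypothetical example; the exact layout depends on the Hugging Face dataset being prepared.
    row = {
        "text": "A photo of a cat: and a dog.",
        "images": [{"bytes": b"<raw JPEG/PNG bytes>", "path": None}],  # decode=False keeps raw bytes
        "image_positions": [18],  # character offset in "text" where the image occurs
    }

The tokenizer's new tokenize() then converts these character offsets (and any loss-masking character spans) into token positions in a single pass.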
@@ -298,18 +314,17 @@ def run(self) -> None: if self._text_column not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._text_column}'.") - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( + if self._loss_masking_spans_column is not None and ( self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._loss_masking_spans_column not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: if self._config.dataset.chosen_text not in dataset.column_names: raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") @@ -329,6 +344,13 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._images_column + else 0 + ) + # Add the token-equivalent bytes of pixels to determine shard size + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -357,7 +379,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -397,7 +419,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -427,10 +453,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. 
# TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -443,8 +479,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c74586207..d46e38935 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -41,44 +41,75 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ - input_ids = [] + if not image_positions: + image_positions = [] + if not char_spans: + char_spans = [] + + # Collect all positions with their type + positions = [] + for pos in image_positions: + positions.append((pos, "image")) + + for start, end in char_spans: + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. 
We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + + token_ids = [] token_spans = [] + image_token_positions = [] char_pos = 0 - beginning_of_text = True + current_span_start = None - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 + for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times + if char_pos < position[0]: + tokenized_text = self._tokenize( + text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 + ) + token_ids.extend(tokenized_text) + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0), end=True) + token_ids.extend(tokenized_text) + + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1737f4308..1849a2316 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -136,7 +136,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 7829c243b..132bfd383 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -153,7 +153,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad + return 
input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 272b7c6ae..a5e0a86a6 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -48,6 +48,12 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + max_image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 531bc206b..809d46801 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -361,7 +361,7 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - Assert.eq(self.model.distributed.tensor_parallel, 1) + # Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 684193848..5c8d75a6f 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: torch.nn.functional.gelu, + ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -78,7 +80,8 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu", + ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", @@ -86,9 +89,16 @@ def _set_activation_fn_map() -> None: } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} + MAX_DROPLESS_BLOCK_SIZE_ROW = 128 +class ReverseKLImpl(str, enum.Enum): + tp = "tp" + stp = "stp" + no_tp = "no_tp" + + class CrossEntropyImpl(str, enum.Enum): auto = "auto" torch = "torch" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d56dce98d..d9ca547a7 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat +from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert 
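Note (illustrative, not part of the patch): the activation change in fast_llm/functional/config.py above now distinguishes exact GELU (ActivationType.gelu, HF name "gelu") from the tanh approximation (ActivationType.gelu_pytorch_tanh, HF name "gelu_pytorch_tanh"), whereas previously "gelu" was silently mapped to the tanh variant. The two differ slightly:

    # Quick check of the two variants now mapped separately in _ACTIVATION_FN_MAP.
    import torch

    x = torch.linspace(-3.0, 3.0, steps=7)
    exact = torch.nn.functional.gelu(x)                       # ActivationType.gelu
    approx = torch.nn.functional.gelu(x, approximate="tanh")  # ActivationType.gelu_pytorch_tanh
    print((exact - approx).abs().max())  # small but nonzero difference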
@@ -49,6 +49,19 @@ def _torch_cross_entropy_forward_backward( return loss.detach_(), grad +def distributed_log_softmax(logits: torch.Tensor, group: ProcessGroup, dim: int = -1): + logits = logits.float() + local_max = logits.max(dim=dim, keepdim=True)[0] + all_reduce(local_max, op=ReduceOp.MAX, group=group) + + logits_shifted = logits - local_max + exp_logits = torch.exp(logits_shifted) + sum_exp = exp_logits.sum(dim=dim, keepdim=True) + all_reduce(sum_exp, op=ReduceOp.SUM, group=group) + + return logits_shifted - sum_exp.log() # log_softmax + + @torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 @@ -151,7 +164,8 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() return loss, grad @@ -213,20 +227,30 @@ def cross_entropy_forward_backward( ) -def _torch_reverse_kl_forward_backward( +def _torch_reverse_kl_forward_backward_vocab_parallel( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - Much simpler and more reliable than custom implementation! + This is used for TP version where we split accross vocab dimantion. + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
""" + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") + # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype @@ -234,32 +258,78 @@ def _torch_reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # Compute log probabilities - let _fused_softmax handle scaling internally - # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) - # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) - # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 - # teacher_log_probs = torch.log(teacher_probs) + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + batch_size = logits.shape[0] + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits_, group=group) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + raise NotImplementedError("Loss mask not implemented with TP for reverse KL , it must be doublechecked") + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() + + if group is not None and target_format != TargetFormat.labels: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= batch_size + + if grad_output is not None: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + +def _torch_reverse_kl_forward_backward_no_tp( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + teacher_softmax_temperature: float = 1.0, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Reverse KL using PyTorch's native kl_div function. + THis is only used for no-TP case. 
+    """
+    Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
+    Assert.eq(target.shape, logits.shape)
+    assert target.dtype.is_floating_point, target.dtype
+    if loss_mask is not None:
+        Assert.eq(loss_mask.shape, logits.shape[:-1])
 
     # Scale target logits more carefully
     scaled_target = target * (logits_scale_factor / teacher_softmax_temperature)
+    # Clamp to prevent extreme values that cause NaNs in log_softmax
+    scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0)
 
-    # Clamp to prevent extreme values before log_softmax
-    scaled_target = torch.clamp(scaled_target, min=-50, max=50)
-    teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
+    teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1)
 
     # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p))
     # Use kl_div with: input=log(p), target=q, log_target=False
     # This gives: Σ q * (log(q) - log(p)) = exactly what we want!
 
     with torch.enable_grad():
-        logits_ = logits.detach().requires_grad_(grad_output is not None)
+        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
 
-        # Use log_softmax for consistency instead of _fused_softmax
         scaled_logits = logits_ * logits_scale_factor
-        scaled_logits = torch.clamp(scaled_logits, min=-50, max=50)
-        student_log_probs = torch.log_softmax(scaled_logits, dim=-1)
-
-        # Convert to probabilities for kl_div
-        # student_probs_ = torch.exp(student_log_probs)
+        # Clamp to prevent extreme values that cause NaNs in log_softmax
+        scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
+        student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1)
 
         # Reverse KL: input=teacher_log_probs, target=student_probs
         if loss_mask is None:
@@ -274,12 +344,85 @@ def _torch_reverse_kl_forward_backward(
             loss_per_sample = torch.nn.functional.kl_div(
                 teacher_log_probs, student_log_probs, reduction="none", log_target=True
             ).sum(dim=-1)
-            loss = (loss_per_sample * loss_mask).mean()
+            loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum()
 
-    if group is not None and target_format != TargetFormat.labels:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+    if grad_output is not None:
+        # Note: we never get here in the TP-over-sequence-dimension case.
+        loss.backward(torch.full_like(loss, grad_output))
+        grad = logits_.grad.to(logits.dtype)
+    else:
+        grad = None
+
+    return loss.detach_(), grad
+
+
+def _torch_reverse_kl_forward_backward_sequence_tensor_parallel(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+    loss_mask: torch.Tensor | None,
+    grad_output: float | None,
+    logits_scale_factor: float,
+    target_format: TargetFormat,
+    teacher_softmax_temperature: float = 1.0,
+    total_valid_tokens: int | None = None,  # total number of unmasked tokens in the batch
+    **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """
+    Reverse KL using PyTorch's native kl_div function.
+    This is only used for the sequence-tensor-parallel case, where we split over the sequence dimension.
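+    Returns the local sum of per-token losses; the caller is expected to all-reduce it and divide by the
+    total number of valid tokens.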
+    """
+    assert (
+        total_valid_tokens is not None
+    ), "Total valid tokens must be provided for sequence-tensor-parallel reverse KL"
+    Assert.eq(
+        teacher_softmax_temperature,
+        1,
+        msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL",
+    )
+    Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL")
+    Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
+    Assert.eq(target.shape, logits.shape)
+    assert target.dtype.is_floating_point, target.dtype
+    if loss_mask is not None:
+        Assert.eq(loss_mask.shape, logits.shape[:-1])
+    # Scale target logits more carefully
+    scaled_target = target * (logits_scale_factor / teacher_softmax_temperature)
+    # Clamp to prevent extreme values that cause NaNs in log_softmax
+    scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0)
+
+    teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1)
+
+    # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p))
+    # Use kl_div with: input=log(p), target=q, log_target=False
+    # This gives: Σ q * (log(q) - log(p)) = exactly what we want!
+
+    with torch.enable_grad():
+        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
+
+        scaled_logits = logits_ * logits_scale_factor
+        # Clamp to prevent extreme values that cause NaNs in log_softmax
+        scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
+        student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1)
 
+        # Reverse KL: input=teacher_log_probs, target=student_probs
+        if loss_mask is None:
+            loss = torch.nn.functional.kl_div(
+                teacher_log_probs,  # input = log(p)
+                student_log_probs,  # target = log(q)
+                reduction="sum",
+                log_target=True,
+            )
+        else:
+            # Apply loss mask - this requires some reshaping
+            loss_per_sample = torch.nn.functional.kl_div(
+                teacher_log_probs, student_log_probs, reduction="none", log_target=True
+            ).sum(dim=-1)
+            loss = (loss_per_sample * loss_mask).sum()  # this can be 0.0 if all tokens are masked
 
     if grad_output is not None:
+        # Note: the gradient is computed w.r.t. the sum of the per-token losses,
+        # so grad_output must already include the 1 / (number of valid tokens) scaling.
         loss.backward(torch.full_like(loss, grad_output))
         grad = logits_.grad.to(logits.dtype)
     else:
@@ -288,6 +431,13 @@ def _torch_reverse_kl_forward_backward(
     return loss.detach_(), grad
 
 
+REVERSE_KL_IMPLEMENTATIONS = {
+    ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward_no_tp,
+    ReverseKLImpl.tp: _torch_reverse_kl_forward_backward_vocab_parallel,
+    ReverseKLImpl.stp: _torch_reverse_kl_forward_backward_sequence_tensor_parallel,
+}
+
+
 def reverse_kl_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
@@ -297,6 +447,8 @@ def reverse_kl_forward_backward(
     logits_scale_factor: float = 1.0,
     teacher_softmax_temperature: float = 1.0,
     target_format: TargetFormat = TargetFormat.labels,
+    reverse_kl_impl: ReverseKLImpl = ReverseKLImpl.no_tp,
+    total_valid_tokens: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher).
@@ -339,7 +491,15 @@ def reverse_kl_forward_backward(
     assert target.dtype.is_floating_point, target.dtype
     if loss_mask is not None:
         Assert.eq(loss_mask.shape, logits.shape[:-1])
-    # TODO: implement fused?
- return _torch_reverse_kl_forward_backward( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature + # TODO: implement fused reverse KL? + return REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, + total_valid_tokens=total_valid_tokens, ) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ab408368f..f3d9d7d0c 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -47,8 +47,7 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - # Triton doesn't like enums, so we use str instead of ActivationType. - if activation_type == "gelu": + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -98,8 +97,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - # Triton doesn't like enums, so we use str instead of ActivationType. - if activation_type == "gelu": + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9a940f4cb..9a2be6b49 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -340,7 +340,7 @@ def forward( max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.dropout if self.training else 0.0, window_size=window_size, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -350,7 +350,7 @@ def forward( value, window_size=window_size, dropout_p=self._config.dropout if self.training else 0.0, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 2910c7c76..924e06057 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -80,6 +80,11 @@ class AttentionConfig(MixerConfig): desc="Add biases to linear layers. May be overridden for individual layers.", hint=FieldHint.architecture, ) + causal: bool = Field( + default=True, + desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) dropout: float = Field( default=0.0, desc="Dropout applied to the attention intermediate states.", diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 5bd7a9b87..5e24af9a6 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -135,3 +135,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "rope_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 889711839..58a786947 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -13,9 +13,11 @@ DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -199,3 +201,71 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: DefaultRotaryConfig](DefaultRotary[Rotary2DConfig]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(max_num_patches) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + kwargs[VisionTransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_q_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_q, + ) + kwargs[VisionTransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_k_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_k, + ) + + def _create_tensors(self, max_num_patches: int) -> None: + if max_num_patches <= self._tensor_cache_max_num_patches: + return + self._tensor_cache_max_num_patches = max_num_patches + + self._rotary_embedding_frequencies = self._get_frequencies( + max_num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + + def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda") -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + width_positions = torch.arange(max_num_patches, device=device, 
dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index f59b4cffd..5bd35eb71 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -9,6 +9,7 @@ from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -34,6 +35,7 @@ class LanguageModelKwargs(BlockKwargs): position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" @@ -48,6 +50,10 @@ class LanguageModelEmbeddingsConfig(BlockConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) + vision_encoder: VisionEncoderConfig = Field( + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) position_embeddings: OptionalParameterConfig = Field( desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index ade1144d2..eb8302025 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -11,7 +11,13 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig +from fast_llm.functional.config import ( + CrossEntropyImpl, + DistillationLossImpl, + ReverseKLImpl, + TargetFormat, + TritonConfig, +) from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward @@ -242,10 +248,24 @@ def _get_targets( ).flatten() else: lm_target = None - - targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] + targets = (dpo_target, lm_target, distillation_target) + # If we do distillation, no need to split it here as it has already been split in the embedding layer! + # if we do CPT/language modeling, we need to split the targets here! 
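+        # Split here only when the targets were not already split upstream: distillation targets with
+        # non-parallel, non-sequence-parallel embeddings, or regular LM targets with sequence-parallel logits.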
+ if ( + self._config.distillation_model is not None + and self._sequence_parallel_logits + and not self._parallel_embeddings + and not self._sequence_parallel + ) or (self._config.distillation_model is None and self._sequence_parallel_logits): + # We dont split targets if they already have been split in the embedding layer! + targets = [ + None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) + for target in targets + ] + # Loss mask may need to be split. It was not split in the embedding layer as it is not used there. + if loss_mask is not None and self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) + targets = (*targets, loss_mask) if not any(target is not None for target in targets): # Simplify so we don't have to check every time. targets = None @@ -305,14 +325,13 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._config.cross_entropy_splits or 1) * ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 - ) - if loss_count != 1: - loss.div_(loss_count) - if self._sequence_parallel_logits: - # TODO: Async - all_reduce(loss, group=self._parallel_dim.group) + assert self._cross_entropy_splits is None, "This is not supported for now" + # loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) + # if loss_count != 1: + # loss.div_(loss_count) + # if self._sequence_parallel_logits: + # # TODO: Async + # all_reduce(loss, group=self._tensor_space.distributed.tensor_group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -343,13 +362,34 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._config.logits_scale_factor, ) - if self._debug.enabled and self._config.cross_entropy_splits is None: - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q - batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] - dims = ( - (sequence_dim, batch_dim, self._vocab_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, self._vocab_dim) + if self._debug_transformer and self._cross_entropy_splits is None: + vocab_dim = self._tensor_space[ + LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp + ] + dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] + sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims[sequence_index] = ( + TensorDim( + TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + ) + if self._sequence_parallel_logits + else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + ) + + dim_names = ( + [TransformerDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + if self._sequence_parallel_logits + else [TransformerDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + ) + + dim_names.insert(int(kwargs[TransformerKwargs.sequence_first]), TransformerDimNames.batch) + log_distributed_tensor( + "", + logits, + level=self._debug_transformer, + meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), + distributed=self._tensor_space.distributed, + 
scale=self._logits_scale_factor, ) self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) @@ -385,8 +425,31 @@ def _logits_cross_entropy_forward_backward( else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if distillation_target is not None and self._distillation_loss_factor > 0.0: + if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + local_valid_tokens = total_valid_tokens = logits.shape[0] + if logits.shape[-1] != self._config.vocab_size: + reverse_kl_impl = ReverseKLImpl.tp + assert loss_mask is None, "Loss mask is not implemented for TP (vocab dim) reverse KL yet" + elif self._sequence_parallel_logits: + # grad_output already reflects scaling 1/ number of ranks (group_size), see _forward_backward + reverse_kl_impl = ReverseKLImpl.stp + if loss_mask is not None: + local_valid_tokens = loss_mask.sum() + total_valid_tokens = local_valid_tokens.clone() + all_reduce( + total_valid_tokens, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group + ) + else: + local_valid_tokens = logits.shape[0] + total_valid_tokens = local_valid_tokens * self._group_size + # in the loss function we compute grads w.r.t sum of losses, + # so we need to multiply back by the group size and divide by the number of valid tokens to get the correct scaling + # note, the function returns the sum of local losses, so we need to handle this properly for reporting + grad_output *= self._group_size / total_valid_tokens # multiply back by the group size + else: + reverse_kl_impl = ReverseKLImpl.no_tp + distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -398,8 +461,15 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), + reverse_kl_impl=reverse_kl_impl, + total_valid_tokens=total_valid_tokens, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: + if self._sequence_parallel_logits: + # distillation_loss is local sum, so we need to divide by the number of valid tokens to get the correct scaling + all_reduce(distillation_loss, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group) + distillation_loss /= total_valid_tokens # final global loss + + elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 000000000..a5a789f9e --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,183 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import reduce_forward, split +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.tensor import 
TensorMeta +from fast_llm.utils import Assert, div + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + + # @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. + """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + # Cloning since we will modify the embeddings in-place + embeddings = embeddings.clone() + # the embeddings tensor are full-sized, but we might get a split of the patch embeddings + # We need to determine the offset in the embeddings tensor for each sample + # and also account for the special image tokens if applicable + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + image_embedding_offset += num_patches + continue + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_src + patch_width <= patch_start_offset: + continue + + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst + max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) + if self._sequence_parallel: + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx + ] + else: + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, 
input_start_index:input_end_index + ] + else: + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx + ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] + image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + # Move to the next image in the input tensor + image_embedding_offset += num_patches + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + tokens = kwargs.get(LanguageModelKwargs.tokens) + + return 
self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/ssm/preprocessing.py b/fast_llm/layers/ssm/preprocessing.py new file mode 100644 index 000000000..343f0bb28 --- /dev/null +++ b/fast_llm/layers/ssm/preprocessing.py @@ -0,0 +1,68 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.ssm.config import SSMKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +class Mamba2Preprocessor(Preprocessor): + def __init__(self, config: HybridSSMBaseModelConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._transformer_dim_names = config.transformer._transformer_dim_names + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + """ + Simplified preprocessor that does not take into account micro-sequences. + """ + if TransformerKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] + if TransformerKwargs.cu_seqlens_k in kwargs: + # already set this in the transformer preprocessor, so we can use it here + cu_seqlens_k = kwargs[TransformerKwargs.cu_seqlens_k] + cu_seqlens_q = kwargs[TransformerKwargs.cu_seqlens_q] + Assert.eq( + cu_seqlens_k.shape[0], + cu_seqlens_q.shape[0], + msg="cu_seqlens_k and cu_seqlens_q have different lengths, is micro_sequence_length being used? This is currently not supported for Mamba.", + ) + Assert.all_equal(cu_seqlens_k, cu_seqlens_q) + cu_seqlens = cu_seqlens_k + else: + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), + torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + ) + ) + kwargs[SSMKwargs.cu_seqlens] = cu_seqlens + # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152 + kwargs[SSMKwargs.seq_idx] = torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) + ], + dim=0, + ).unsqueeze(0) + + sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + position_ids = torch.stack( + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + ).to(self._tensor_space.distributed.device, dtype=torch.int64) + position_ids = position_ids[ + :, sequence_k - sequence_q : sequence_k + ] # this is only needed if we do micro-sequences? 
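+        # Note: position ids restart at zero for every packed document, since they are built by
+        # concatenating arange(length) over each sample's sequence_lengths.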
+ kwargs[SSMKwargs.ssm_position_ids] = position_ids.to(torch.int32) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 000000000..7ec50dfee --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,55 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer for the LLM. + """ + + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + input_dim = tensor_space[VisionEncoderDimNames.out_channels] + self._activation_type = config.adapter_activation_type + self.layer_1 = Linear( + input_dim, + tensor_space[VisionEncoderDimNames.adapter_size], + bias=True, + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), + lr_scale=config.adapter_lr_scale, + ) + self.layer_2 = Linear( + tensor_space[VisionEncoderDimNames.adapter_size], + tensor_space[TransformerDimNames.hidden], + bias=True, + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), + lr_scale=config.adapter_lr_scale, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 000000000..a705d948a --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,181 @@ +import enum + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.utils import Assert + + +class VisionEncoderDimNames: + in_channels = "vision_in_channels" + out_channels = "vision_out_channels" + adapter_size = "vision_adapter_size" + patch_size = "vision_patch_size" + kv_channels = "vision_kv_channels" + + +class VisionEncoderKwargs: + patch_size = "patch_size" + images = "images" + image_patches = "image_patches" + image_positions = "image_positions" + max_image_size = "max_image_size" + image_sizes = "image_sizes" + image_mean = "image_normalization_mean" + image_std = "image_normalization_std" + image_rescale_factor = "image_rescale_factor" + rope_theta = "vit_rope_theta" + rotary_inv_freq = "vit_rotary_inv_freq" + 
kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + patch_embeddings = "patch_embeddings" + hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" + out_channels = "vit_out_channels" + + +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + +class VisionEncoderType(str, enum.Enum): + none = "none" + # TODO: better name? normalization, patch size, adapter can change based on implementation, no standard way currently. + pixtral = "pixtral" + + +@config_class(registry=True) +class VisionEncoderConfig(BaseModelConfig): + _abstract = False + + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. Choices: none, pixtral.", + hint=FieldHint.architecture, + ) + transformer: TransformerConfig = Field( + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + patch_norm: NormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. 
If None, no token id is applied.", + hint=FieldHint.optional, + ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + adapter_init_method_std: float = Field( + default=None, + desc="Standard deviation for the normal initialization of the adapter weights. Default: adapter_size ** -0.5.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.adapter_init_method_std is None: + self.adapter_init_method_std = self.adapter_size**-0.5 + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) + self.transformer.setup_tensor_space(tensor_space) + + @property + def enabled(self) -> bool: + return self.type != VisionEncoderType.none + + +for name in VisionEncoderType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + VisionEncoderConfig.register_subclass(name.value, VisionEncoderConfig) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py new file mode 100644 index 000000000..6c2a70930 --- /dev/null +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -0,0 +1,62 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class PatchConv(Layer): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._lr_scale = config.adapter_lr_scale + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space[VisionEncoderDimNames.out_channels], + self._tensor_space[VisionEncoderDimNames.in_channels], + self._tensor_space[VisionEncoderDimNames.patch_size], + self._tensor_space[VisionEncoderDimNames.patch_size], + ), + init_method=init_normal_(), + lr_scale=self._lr_scale, + ) + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space[VisionEncoderDimNames.out_channels],), + init_method=init_normal_(), + lr_scale=self._lr_scale, + ) + else: + self.bias = None + self.norm = config.patch_norm.get_layer(tensor_space[VisionEncoderDimNames.out_channels]) + self.stride = config.patch_size + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] 
| None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] + sequence_length = kwargs[TransformerKwargs.sequence_length] + out_channels = kwargs[VisionEncoderKwargs.out_channels] + reshape_dims = (micro_batch_size, sequence_length, out_channels) + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + patch_embeddings = patch_embeddings.view(reshape_dims) + if self._sequence_parallel: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + patch_embeddings = split(patch_embeddings, group=group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py new file mode 100644 index 000000000..adacd380c --- /dev/null +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -0,0 +1,281 @@ +import math +import typing + +import torch +import torchvision.transforms.v2 as torchvision_transforms + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.tensor import TensorMeta +from fast_llm.utils import div + + +def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. 
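+    The returned height and width are always rounded up to multiples of patch_size.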
+ """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) + + +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + target_height, target_width = get_resize_dims( + image.size(1), image.size(2), max_height, max_width, patch_size=patch_size + ) + height, width = image.size(1), image.size(2) + while height > 2 * target_height or width > 2 * target_width: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + intermediate_max_width = max(target_width, width // 2) + intermediate_max_height = max(target_height, height // 2) + height, width = get_resize_dims( + height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size + ) + image = torchvision_transforms.functional.resize( + image, size=(height, width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + # TODO: options for interpolation mode? + return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + +def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: + """ + Normalize the image using the specified mean and standard deviation. + """ + return torchvision_transforms.functional.normalize(image, mean=mean, std=std) + + +def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: + """ + Pad images on the right and bottom with 0s untitl max_height and max_width + """ + width_padding = max(0, max_height - image.size(1)) + depth_padding = max(0, max_width - image.size(2)) + return torchvision_transforms.functional.pad(image, (0, 0, depth_padding, width_padding), 0) + + +def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: + freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) + max_patches_per_side = max_image_size // patch_size + + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + return torch.cat((inv_freq, inv_freq), dim=-1) + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + return ids[:, 0] + + +class VisionPreprocessor(Preprocessor): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( + ( + TensorDim( + VisionTransformerDimNames.batch, + kwargs[TransformerKwargs.micro_batch_size] * 
kwargs[TransformerKwargs.sequence_q_dim].size, + ), + TensorDim(VisionEncoderDimNames.in_channels, 3), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + ), + dtype=self._distributed_config.training_dtype.torch, + ) + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + images = kwargs.get(VisionEncoderKwargs.images) + max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) + im_width = kwargs.get(VisionEncoderKwargs.max_image_size) + patch_size = kwargs[VisionEncoderKwargs.patch_size] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + image_sizes = [ + [get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=patch_size) for im in ims] + for ims in images + ] + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes + images = [ + [ + normalize( + resize(image, max_image_size, im_width, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch + ) + / kwargs[VisionEncoderKwargs.image_rescale_factor], + mean=kwargs[VisionEncoderKwargs.image_mean], + std=kwargs[VisionEncoderKwargs.image_std], + ) + for image in imgs + ] + for imgs in images + ] + + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() + + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + kwargs.get(TransformerKwargs.sequence_first) + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] + sample_cu_seqlen = 0 + for image, size, position in zip(imgs, sizes, positions): + seqlen = get_num_patches(*size, patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + if LanguageModelKwargs.labels in kwargs: + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ), + ] + ) + ) + padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + if sizes: + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, max_image_size // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = 
torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionEncoderKwargs.rope_theta], + kwargs[VisionEncoderKwargs.kv_channels], + max_image_size, + patch_size, + ).to(device=self._tensor_space.distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. + kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed_config.training_dtype.torch).min, + dtype=self._distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py new file mode 100644 index 000000000..534d813ff --- /dev/null +++ b/fast_llm/models/custom/model.py @@ -0,0 +1,70 @@ +import typing + +import torch + +from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.engine.base_model.base_model import Layer, LossDef +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig +from fast_llm.models.custom.head import CustomHead +from fast_llm.models.gpt.config import GPTBaseModelConfig +from fast_llm.models.gpt.model import GPTBaseModel, GPTModel +from fast_llm.tensor import TensorMeta + + +class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): + config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig + + def __init__( + self, + config: CustomBaseModelConfig, + distributed_config: DistributedConfig, + ): + # TODO: Implement / update. 
+ super().__init__(config, distributed_config) + + def get_layers(self) -> list[Layer]: + # TODO: Adjust as needed. + return [ + LanguageModelEmbedding(self._config, self._tensor_space), + *[ + TransformerBlock( + self._config.transformer, + self._tensor_space, + block_index=i + 1, + ) + for i in range(self._config.transformer.num_layers) + ], + CustomHead(self._config, self._tensor_space), + ] + + def preprocess_meta( + self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType + ) -> list[tuple[TensorMeta, dict]]: + # TODO: Adjust or reimplement. + return super().preprocess_meta(batch_meta, phase) + + def preprocess( + self, + batch: GPTBatch, + preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + *, + phase: PhaseType, + iteration: int, + metrics: dict | None = None, + ) -> list[tuple[torch.Tensor, dict]]: + # TODO: Adjust or reimplement. + return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) + + @property + def loss_defs(self) -> list[LossDef]: + # TODO: Adjust or reimplement. + return super().loss_defs + + +class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): + config_class: typing.ClassVar[type[CustomModelConfig]] = CustomModelConfig + base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 8fbb99cad..7712d764d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -16,9 +16,12 @@ DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, LlamaCheckpointFormat, + LlavaCheckpointFormat, + LlavaHybridCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, MTPLlamaCheckpointFormat, + PixtralCheckpointFormat, Qwen2CheckpointFormat, ) from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds @@ -104,6 +107,9 @@ class GPTModelConfig(FastLLMModelConfig): DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, AprielHybridSSMCheckpointFormat, + PixtralCheckpointFormat, + LlavaCheckpointFormat, + LlavaHybridCheckpointFormat, ) @classmethod @@ -124,6 +130,25 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelF return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @@ -177,6 +202,10 @@ def _validate(self) -> None: ) Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) + if self.model.base_model.vision_encoder.enabled: + assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder" + Assert.gt(self.batch.max_image_size, 0) + @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer 
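
The dict form accepted by `get_checkpoint_format` above lets a checkpoint config pick the Llava variant by name and optionally override which vision/text sub-formats its converters delegate to. A minimal usage sketch, assuming only the formats registered in this patch:

    from fast_llm.models.gpt.config import GPTModelConfig

    # Plain string: resolved by name against GPTModelConfig.checkpoint_formats.
    llava_format = GPTModelConfig.get_checkpoint_format("llava")
    assert llava_format.vision_name == "pixtral" and llava_format.text_name == "mistral"

    # Dict form: same lookup by "name", with optional overrides for the vision/text
    # sub-handlers (note that this mutates the format class attributes in place).
    llava_format = GPTModelConfig.get_checkpoint_format(
        {"name": "llava", "text_name": "apriel_ssm_thinker_hybrid"}
    )
    assert llava_format.text_name == "apriel_ssm_thinker_hybrid"
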
diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py
index 659d1f12c..bd2c3d2c7 100644
--- a/fast_llm/models/gpt/conversion/auto.py
+++ b/fast_llm/models/gpt/conversion/auto.py
@@ -8,9 +8,12 @@
     DiffusionDreamCheckpointFormat,
     DiffusionLlamaCheckpointFormat,
     LlamaCheckpointFormat,
+    LlavaCheckpointFormat,
+    LlavaHybridCheckpointFormat,
     MistralCheckpointFormat,
     MixtralCheckpointFormat,
     MTPLlamaCheckpointFormat,
+    PixtralCheckpointFormat,
     Qwen2CheckpointFormat,
 )
 from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler
@@ -35,4 +38,7 @@ class AutoGPTHuggingfaceCheckpointHandler(
         DiffusionDreamCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler,
         DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler,
         AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler,
+        PixtralCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler,
+        LlavaCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler,
+        LlavaHybridCheckpointFormat.name: LlavaHybridHuggingfaceCheckpointHandler,
     }
diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py
index 7c06906ad..f6e3d65c9 100644
--- a/fast_llm/models/gpt/conversion/config.py
+++ b/fast_llm/models/gpt/conversion/config.py
@@ -47,3 +47,27 @@ class DiffusionLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 
 class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "apriel_hybrid_ssm"
+
+
+class LlavaCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "llava"
+    # Using default values for the vision and text models. Can be overridden in the config.
+    vision_name: typing.ClassVar[str] = "pixtral"
+    text_name: typing.ClassVar[str] = "mistral"
+
+
+class PixtralCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "pixtral"
+
+
+class LlavaHybridCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "llava_hybrid"
+    vision_name: typing.ClassVar[str] = "pixtral"
+    text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid"
+    trust_remote_code: typing.ClassVar[bool] = True
+
+    @classmethod
+    def get_handler_class(cls) -> type[CheckpointHandler]:
+        from fast_llm.models.gpt.conversion.llava_hybrid import LlavaHybridHuggingfaceCheckpointHandler
+
+        return LlavaHybridHuggingfaceCheckpointHandler
diff --git a/fast_llm/models/gpt/conversion/llava.py b/fast_llm/models/gpt/conversion/llava.py
new file mode 100644
index 000000000..996269862
--- /dev/null
+++ b/fast_llm/models/gpt/conversion/llava.py
@@ -0,0 +1,161 @@
+import typing
+
+from fast_llm import __version__
+from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value
+from fast_llm.engine.base_model.config import BaseModelConfig
+from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig
+from fast_llm.engine.checkpoint.external import (
+    ConstantExportParamConverter,
+    ExternalStateDictCheckpointHandler,
+    MappedConfigParamConverter,
+    ParamConverter,
+    RenameParamConverter,
+)
+from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
+from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig
+from fast_llm.functional.config import ActivationType
+from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig
+from fast_llm.models.gpt.conversion.auto import AutoGPTHuggingfaceCheckpointHandler
+from fast_llm.models.gpt.conversion.config import LlavaCheckpointFormat as LlavaGPTHuggingfaceCheckpointFormat
+
+
+class
LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaForConditionalGeneration" + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def get_vision_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = text_handler_cls._import_config_dict(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = vision_handler_cls._import_config_dict(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + # handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = cls.get_vision_handler_class() + 
text_handler_cls = cls.get_text_handler_class() + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = self.get_vision_handler_class() + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) + text_handler_cls = self.get_text_handler_class() + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters diff --git a/fast_llm/models/gpt/conversion/llava_hybrid.py b/fast_llm/models/gpt/conversion/llava_hybrid.py new file mode 100644 index 000000000..45eb1cf2d --- /dev/null +++ b/fast_llm/models/gpt/conversion/llava_hybrid.py @@ -0,0 +1,40 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.models.gpt.conversion.config import LlavaHybridCheckpointFormat +from fast_llm.models.gpt.conversion.llava import LlavaHuggingfaceCheckpointHandler + + +class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" + modeling_file = modeling_llava_hybrid.__file__ + configuration_file = configuration_llava_hybrid.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + additional_files = [ + modeling_ssm_hybrid_apriel15b.__file__, + configuration_ssm_hybrid_apriel15b.__file__, + ] + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + from 
fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + }, + ), + ] diff --git a/fast_llm/models/gpt/conversion/pixtral.py b/fast_llm/models/gpt/conversion/pixtral.py new file mode 100644 index 000000000..da055a5db --- /dev/null +++ b/fast_llm/models/gpt/conversion/pixtral.py @@ -0,0 +1,266 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.rotary.config import Rotary2DConfig +from fast_llm.layers.common.normalization.config import LayerNormalizationConfig +from fast_llm.models.gpt.conversion.llama import KeyValueWeightConverter, MLPLayer2Converter, QueryWeightConverter +from fast_llm.utils import Assert + + +class PixtralNumHeadsConverter(ParamConverter): + """ + Pixtral encoder uses Multi-Head Attention. + Map `num_attention_heads` and `head_groups` to a single `num_heads` parameter. + """ + + def __post_init__(self): + Assert.eq(len(self.fast_llm_names), 2) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads, head_groups) = fast_llm_values + assert head_groups == num_heads, "Pixtral encoder expects num_heads == head_groups (MHA)" + return (num_heads,) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads,) = export_values + return (num_heads, num_heads) + + +class PixtralRotaryParamConverter(ParamConverter): + """ + Pixtral encoder uses 2D Rotary Embeddings. + Map `rope_theta` to a single `rotary` parameter. `rotary_scaling` is not needed. 
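+    On export, only the theta value is written to `rope_theta`; on import, it is reconstructed as a `rope_2d` rotary config.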
+ """ + + def __init__(self, fast_llm_names, export_names): + Assert.eq(len(fast_llm_names), 1) + Assert.eq(len(export_names), 1) + self.fast_llm_names = fast_llm_names + self.export_names = export_names + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_config,) = fast_llm_values + if type(rotary_config) is Rotary2DConfig: + return (rotary_config.theta,) + else: + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_theta,) = export_values + rotary_config = { + "type": "rope_2d", + "theta": rotary_theta, + } + return (rotary_config,) + + +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = FastLLMModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), + ConstantImportParamConverter(fast_llm_names=(("patch_norm", "type"),), fast_llm_value="rms_norm"), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="image_encoder"), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), + ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), + ), + export_names=(("hidden_size",),), + ), + PixtralNumHeadsConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), + ( + "transformer", + "head_groups", + ), + ), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), + ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", + ), + ), + export_names=(("head_dim",),), + ), + # ConstantImportParamConverter( + # fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d + # ), + # RenameParamConverter( + # fast_llm_names=( + # ( + # "transformer", + # "rotary", + # "theta", + # ), + # ), + # export_names=(("rope_theta",),), + # ), + PixtralRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: 
+ return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + + def _create_vision_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" + ) -> list[WeightConverter]: + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) + name_bias_cls = [ + # Self-attn + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", + ( + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + hf_prefix, + use_bias, + cls, + ) + # MLP + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) + return converters + + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + norm_bias = isinstance(self._model.config.base_model.vision_encoder.patch_norm, LayerNormalizationConfig) + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + if norm_bias: + converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) + + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if 
self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] + ) + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 diff --git a/fast_llm/models/gpt/llava.py b/fast_llm/models/gpt/llava.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b7d751a61..bbe9f5cb6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -11,10 +11,17 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.vision_encoder.patch_conv import PatchConv +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -44,10 +51,82 @@ def __init__( param, self._config.decoder.block, config.embeddings_layer.hidden_size ) # Noqa # `self._reference_models` is not populated at this point, so we pass a mutable dict. - self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) + self._preprocessors: list[Preprocessor] = [] + if self._config.use_absolute_position_embeddings: + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) + # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. + # TODO: Find a better solution. + self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) + if self._use_flash_attention: + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + else: + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + + if self._config.enable_dpo: # TODO better way to pass in? 
+ self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + + if self._config.vision_encoder.enabled: + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) + self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) + + def get_output_layers(self) -> list[Layer]: + layers = [] + for i in range(self._config.prediction_heads): + if i > 0: + layers.append( + TransformerBlock( + self._config.transformer, + self._tensor_space, + # TODO MTP: which index? + block_index=max(self._config.transformer.num_layers + i, 1), + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=i < self._config.prediction_heads - 1, + ) + ) + layers.append( + LanguageModelHead( + self._config, + self._tensor_space, + prediction_distance=i, + ) + ) + return layers + + def get_vision_layers(self) -> list[Layer]: + vit_layers = [ + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + PatchConv(self._config.vision_encoder, self._tensor_space), + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_embedding_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + else: + return [LanguageModelEmbedding(self._config, self._tensor_space)] def get_layers(self) -> list[Layer]: - return self._config.get_blocks(self._distributed_config) + return [ + *(self.get_embedding_layers()), + *[ + TransformerBlock( + self._config.transformer, + self._tensor_space, + block_index=i + 1, + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. 
+ return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + ) + for i in range(self._config.transformer.num_layers) + ], + *self.get_output_layers(), + ] def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType @@ -67,8 +146,41 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True - batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) + if self._config.vision_encoder.enabled: + try: + max_image_size = batch_meta.max_image_size + except AttributeError: + max_image_size = 256 + logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + image_mean = [ + self._config.vision_encoder.image_normalization.mean_r, + self._config.vision_encoder.image_normalization.mean_g, + self._config.vision_encoder.image_normalization.mean_b, + ] + image_std = [ + self._config.vision_encoder.image_normalization.std_r, + self._config.vision_encoder.image_normalization.std_g, + self._config.vision_encoder.image_normalization.std_b, + ] + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor + vision_kwargs = { + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + VisionEncoderKwargs.max_image_size: max_image_size, + VisionEncoderKwargs.image_mean: image_mean, + VisionEncoderKwargs.image_std: image_std, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + } + else: + vision_kwargs = {} + + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + + if isinstance(batch_meta, GPTBatchConfig): + micro_sequence_length = batch_meta.micro_sequence_length if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -98,11 +210,24 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) + self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, self._hidden_dim) if sequence_first else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) ) + if self._config.vision_encoder.enabled: + vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + vision_kwargs.update( + { + VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + } + ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -110,8 +235,10 @@ def preprocess_meta( AttentionKwargs.hidden_dims: hidden_dims, AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.micro_batch_size: micro_batch_size, LanguageModelKwargs.mask_inputs: not truncate_documents, } + common_kwargs.update(vision_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._distributed_config.sequence_data_rank, @@ -157,7 +284,11 @@ def preprocess_meta( reference_kwargs[name] = 
reference_kwargs_ kwargs["reference_models"] = reference_kwargs - preprocessed_meta.append((tokens, kwargs)) + if self._config.vision_encoder.enabled: + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -203,19 +334,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -235,16 +367,18 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? 
- labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config + labels_cloned = False if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for idx, spans in enumerate(batch.loss_masking_spans): + labels_cloned = True + for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -255,27 +389,72 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, idx] = False + labels[start : end + 1, i] = -100 else: - loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) + labels[i, start : end + 1] = -100 + if self._config.vision_encoder.enabled: + if self._config.vision_encoder.image_break_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True + labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + if self._config.vision_encoder.image_end_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True + labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + # Loss-masking for distillation losses + if self._config.distillation_model is not None: + loss_mask = torch.ones_like(labels, dtype=torch.bool) + loss_mask = torch.where(labels == -100, False, loss_mask) + kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] + ) + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch_images + ] + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[AttentionKwargs.micro_batch_size] + ) + kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) - preprocessed.append((tokens, kwargs)) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + if image_patches is not None: + preprocessed.append((image_patches, kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self.embedding_layer_index] + + @property + def transformer_layers(self) -> list[TransformerBlock]: + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: + if self._config.vision_encoder.enabled: + return self._config.vision_encoder.transformer.num_layers + 2 + else: + return 0 @property def model_head(self) -> LanguageModelHead: @@ -290,7 +469,7 @@ def get_tied_weights(self) -> dict[str, 
tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.output_layer.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 4dbbfbb1c..cc676d184 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -31,4 +31,13 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.output_layer.prediction_heads, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "max_image_size": self._config.batch.max_image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 5c0a2216c..ad76c56de 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import LossKwargs, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -357,7 +357,13 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx: return 0 - return self.key_cache[layer_idx].shape[-2] + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + # return self.key_cache[layer_idx].shape[-2] def reset(self): self.conv_states.zero_() @@ -886,7 +892,7 @@ def forward( self, hidden_states: torch.Tensor, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, + mamba_mask: Optional[torch.Tensor] = None, return_mixer_matrix=False, **kwargs, ): @@ -898,6 +904,10 @@ def forward( assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape + # mamba_mask = ( + # None if seqlen == 1 else mamba_mask + # ) # prevent that hidden_states are expanded to mask's seq. dimention., i.e. 
we do not need apply_mask_to_padding_states when generating single token at a time + # hidden_states = apply_mask_to_padding_states(hidden_states, mamba_mask) ssm_state, conv_state = None, None use_precomputed_states = False @@ -978,7 +988,7 @@ def forward( # Update state (B D W) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( @@ -986,7 +996,10 @@ def forward( weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, - ) + ) # .transpose(1, 2) + # x = apply_mask_to_padding_states(x, mamba_mask).transpose( + # 1, 2 + # ) # zero out everything that comes from padding tokens if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) @@ -1041,14 +1054,14 @@ def step(self, hidden_states, conv_state, ssm_state): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states_input) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states_input) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) # B, d_inner if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) @@ -1216,6 +1229,42 @@ def __init__(self, config: AprielHybridSSMConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and past_key_values is None: + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + return output + class 
KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -1397,6 +1446,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, + mamba_mask=attention_mask, # non-expended mask **kwargs, ) diff --git a/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py new file mode 100644 index 000000000..b8e822d9f --- /dev/null +++ b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py @@ -0,0 +1,117 @@ +from transformers import MistralConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# Copied from configuration_ssm_hybrid_apriel15b.py +# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict +ssm_config_default = { + # discrete mamba2 + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, + # mamba2 + "d_xb": None, # will be set to model dim + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, +} + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config + + +class LlavaHybridConfig(PretrainedConfig): + """ + Configuration class for Llava SSM-Hybrid-decoder model. + """ + + model_type = "llava_hybrid" + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + projector_intermediate_size=4096, + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + # projector_intermediate_size is an addition to the original Llava config + self.projector_intermediate_size = projector_intermediate_size + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + # Load the custom SSM hybrid config if specified + if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": + text_config = AprielSSMHybridConfig(**text_config) + else: + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py new file mode 100644 index 000000000..68073f9cd --- /dev/null +++ b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py @@ -0,0 +1,132 @@ +from torch import nn +from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel +from transformers.activations import ACT2FN + +from .configuration_llava_hybrid import LlavaHybridConfig + +try: + # In the fast-llm repo, import from the SSM modeling file + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielThinkerSSMHybridModel, + HybridMambaAttentionDynamicCache, + ) +except ImportError: + # In the exported checkpoint, import from local file + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaHybridConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.projector_intermediate_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaHybridModel(LlavaModel): + """ + Llava SSM-Hybrid-decoder model. 
+ """ + + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + assert ( + config.text_config.model_type == "apriel_ssm_thinker_hybrid" + ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" + + self.language_model = AprielThinkerSSMHybridModel(config.text_config) + self.post_init() + + +class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = LlavaHybridModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config.text_config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation` + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs diff --git a/setup.cfg b/setup.cfg index 77073ab55..f65f21a85 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.52.4 + transformers==4.53.2 hf-transfer>=0.1.9 datasets>=3.6.0 huggingface-hub>=0.32.6 @@ -52,13 +52,20 @@ HUGGINGFACE = # To install on cpu environment (ex. 
for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.4 + mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba cartesia_pytorch>=0.0.2 -GENERATION = - lm_eval>=0.4.9 +# GENERATION = +# lm_eval>=0.4.9 +# Required for supporting vision inputs +VISION = + # Vision Tools + webp>=0.4.0 + pillow-simd>=9.5.0 + torchvision>=0.20.0 + DEV = # Pre-commit git hook pre-commit>=4.2.0 diff --git a/tests/data/common.py b/tests/data/common.py index d8cc6fff2..b7284872f 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -122,10 +122,10 @@ def compare_indexed_dataset( loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() + text_sizes, image_sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], text_sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) @@ -219,10 +219,15 @@ def __len__(self) -> int: return self._config.num_documents def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64), np.array( + [], dtype=np.int64 + ) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document def get(self, index: int, *args, **kwargs) -> typing.Any: raise NotImplementedError() + + def has_images(self) -> bool: + return False diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 6a2be3dcc..e04af129a 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -79,14 +79,24 @@ def __len__(self) -> int: return len(self._samples) def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + doc_sizes = [] + im_sizes = [] + for index in range(len(self)): + doc_size, im_size = self.get_document_size(index) + doc_sizes.append(doc_size) + im_sizes.append(im_size) + return np.array(doc_sizes, dtype=np.int64), np.array(im_sizes, dtype=np.int64) def get_document_size(self, index: int) -> int: - return len(self._samples[index]) + return len(self._samples[index]), [] def name(self) -> str: return "dataset" + @property + def has_images(self) -> bool: + return False + TEST_DATASET = SimpleGPTIndexedDataset( [ diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index f14f028e1..c836df9fe 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -23,12 +23,10 @@ def _reverse_kl_loss( ): scaled_target = target / teacher_softmax_temperature - scaled_target = torch.clamp(target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) student_log_probs = torch.log_softmax(logits, dim=-1) if loss_mask is None: loss = torch.nn.functional.kl_div( diff --git a/tests/models/test_checkpoint.py 
b/tests/models/test_checkpoint.py index 714abc130..97a618cfc 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -343,12 +343,15 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream") - else transformers.AutoModelForCausalLM - ) - model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() + if model_testing_config.name in ("diffusion_llama", "dream"): + auto_model = transformers.AutoModel + elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"): + auto_model = transformers.AutoModelForVision2Seq + else: + auto_model = transformers.AutoModelForCausalLM + model_as_hf = auto_model.from_pretrained( + hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code + ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/test_ssms.py b/tests/test_ssms.py new file mode 100644 index 000000000..2a338f1ba --- /dev/null +++ b/tests/test_ssms.py @@ -0,0 +1,349 @@ +import inspect +import itertools +import pathlib +from functools import partial + +import pytest +import torch +from mamba2 import Mamba2 + +from fast_llm.config import NoAutoValidate +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.model import HybridSSMModel + +_mamba_varlen = False +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + + _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + else: + _mamba_varlen = False + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + +except (ImportError, RuntimeError): + _mamba_available = False + + +def get_hybrid_config(hybrid_block_layout=["t", "m2"], prediction_heads=1, default_mtp_type=None): + hidden_size = 512 + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout), hidden_size=hidden_size), + ssm=SSMConfig(d_xb=hidden_size, dt_rank=10, d_inner=hidden_size * 2), + hybrid_block_layout=hybrid_block_layout, + prediction_heads=prediction_heads, + default_mtp_type=default_mtp_type, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config + + +@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") +@pytest.mark.slow +def test_load_from_llamba_checkpoint(): + """ + Test to check whether the of Fast-LLM and Huggingface checkpoint loading 
for Llamba-1B produce the same results. + """ + import cartesia_pytorch.Llamba.llamba + + vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + batch_size = 2 + seq_length = 32 + + path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") + format = LLambaHuggingfaceCheckpointFormat + + x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") + + hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") + parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) + hf_logits = hf_model(x)["logits"].cpu() + del hf_model + torch.cuda.empty_cache() + + # Create checkpoint load config + checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) + # Initialize model + model = HybridSSMModel.from_pretrained(checkpoint_config) + param_sum = 0 + for stage in model.stages: + for fsdp in stage.fsdps: + if hasattr(fsdp, "_weight_shard"): + param_sum += torch.sum(fsdp._weight_shard).item() + assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + + # model = GPTModel.from_pretrained(checkpoint_config) + assert model.config.base_model.vocab_size == vocab_size + schedule_config = ScheduleConfig() + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) + batch_config.setup(DistributedConfig.from_dict({})) + batch_config.validate() + schedule_runner = ScheduleRunner( + config=schedule_config, + multi_stage=model, + distributed_config=model.distributed.config, + ) + schedule = Schedule( + multi_stage=model, + batch_config=batch_config, + schedule_config=schedule_config, + distributed_config=model.distributed.config, + phase=PhaseType.inference, + ) + schedule_runner.setup(model.distributed, optimizer=None) + + common_kwargs = { + TransformerKwargs.sequence_first: True, + TransformerKwargs.grad_output: False, + } + input_data = [(x, common_kwargs)] + + schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) + + logits = input_data[0][1]["logits"].cpu() + assert torch.allclose(logits, hf_logits, atol=1e-2) + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] + lengths = [base + 1 if i < rem else base for i in range(packages_num)] + + # split points exclude the final cumulative (seq_len) + split_points = list(itertools.accumulate(lengths))[:-1] + + # cu_seqlens = [0] + split_points + [seq_len] + cu_seqlens = [0] + split_points + [seq_len] + + # index: for each chunk, we emit 0,1,...,length-1 + index = [] + for length in lengths: + index.extend(range(length)) + + # sanity check + assert len(cu_seqlens) - 1 == packages_num + assert sum(lengths) == seq_len + assert len(index) == seq_len + + return cu_seqlens, index + + +# Quick and dirty test for Mamba2 varlen block from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/tests/pack_mamba/test_mamba_layer.py +# TODO: integrate in the testing framework +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") +@pytest.mark.skipif(not _mamba_available, reason="Mamba2 is not available") +@pytest.mark.skipif(not _mamba_varlen, reason="Mamba2 varlen is not available") +def test_mamba_varlen_block(distributed_config, distributed): + """ + Compare that the output and grads of packed and unpacked Mamba2 varlen block are the same. 
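For reference, the packing convention exercised by pack, unpack and generate_random_cu_seqlens above can be summarized in a few lines. This is an illustrative, self-contained sketch with made-up shapes, not code from the patch: cu_seqlens stores cumulative package boundaries within a packed sequence, and unpacking scatters each package into its own zero-padded row.

import torch

# Toy example: a packed sequence of 10 tokens split into packages of lengths 4, 3, 3.
cu_seqlens = torch.tensor([0, 4, 7, 10])
lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 3, 3])
hidden_dim = 8
packed = torch.randn(1, 10, hidden_dim)  # (batch=1, total_tokens, hidden)

# Unpack into (num_packages, max_package_len, hidden), zero-padded on the right,
# mirroring what `unpack` does for a single batch entry.
unpacked = torch.zeros(len(lengths), int(lengths.max()), hidden_dim)
for i, (begin, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
    unpacked[i, : end - begin] = packed[0, begin:end]

# Re-packing just concatenates the valid prefixes back together.
repacked = torch.cat([unpacked[i, : lengths[i]] for i in range(len(lengths))], dim=0).unsqueeze(0)
assert torch.equal(repacked, packed)

The varlen Mamba2 path consumes the packed layout directly, with cu_seqlens and the derived seq_idx / position indices keeping the packages independent, which is exactly what the test compares against the unpacked reference.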
+ """ + hybrid_config = get_hybrid_config(hybrid_block_layout=["m2", "t"]) + tensor_space = TensorSpace(distributed_config=distributed_config) + tensor_space.setup(distributed) + hybrid_config.setup_tensor_space(tensor_space) + layer_idx = 0 + + mixer_cls = partial(Mamba2, block_index=layer_idx) + block_packed = SSMBlock( + hybrid_config.transformer, + hybrid_config.ssm, + mixer_cls=mixer_cls, + tensor_space=tensor_space, + block_index=layer_idx, + ) + block_ref = SSMBlock( + hybrid_config.transformer, + hybrid_config.ssm, + mixer_cls=mixer_cls, + tensor_space=tensor_space, + block_index=layer_idx, + ) + device = "cuda" + materialize_meta_tensors(block_packed, tensor_space) + materialize_meta_tensors(block_ref, tensor_space) + block_ref.load_state_dict(block_packed.state_dict()) + block_packed.to(device) + block_ref.to(device) + + batch_size = 2 + seq_len = 64 + packages_num = 2 + hidden_dim = hybrid_config.transformer.hidden_size + + cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num=packages_num) + cu_seqlens = torch.tensor(cu_seqlens).cuda() + ssm_position_ids = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda() + seq_idx = ( + torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) + ], + dim=0, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + + # Generate packed_hidden_states with random values for testing + hidden_states_list = [ + torch.randn(l, hidden_dim, device=device, dtype=torch.bfloat16, requires_grad=True) + for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + ] + packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) + packed_hidden_states = packed_hidden_states.expand(batch_size, -1, -1).contiguous() + # hidden_states should be forwarded without cu_seqlens + hidden_states = unpack(packed_hidden_states, cu_seqlens) + + # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states + assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] + # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states + assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] + + output_states_packed = block_packed( + packed_hidden_states, + {"cu_seqlens": cu_seqlens, "seq_idx": seq_idx, "ssm_position_ids": ssm_position_ids, "sequence_first": False}, + ) + output_states_unpacked = block_ref( + hidden_states.clone(), {"cu_seqlens": None, "seq_idx": None, "ssm_position_ids": None, "sequence_first": False} + ) + tollerance = 1e-4 + assert output_states_packed.shape == packed_hidden_states.shape + assert output_states_unpacked.shape == hidden_states.shape + assert not torch.isnan(hidden_states).any() + assert not torch.isinf(hidden_states).any() + + output_states_unpacked = pack(output_states_unpacked, cu_seqlens, batch_size) + torch.allclose(output_states_packed, output_states_unpacked, atol=tollerance) + + loss = output_states_packed.sum() + loss.backward() + loss_ref = output_states_unpacked.sum() + loss_ref.backward() + assert torch.allclose(block_packed.mixer.conv1d_weight.grad, block_ref.mixer.conv1d_weight.grad, atol=tollerance) + assert torch.allclose(block_packed.mixer.conv1d_bias.grad, block_ref.mixer.conv1d_bias.grad, atol=tollerance) + assert torch.allclose( + block_packed.mixer.in_proj.weight.grad_buffer, block_ref.mixer.in_proj.weight.grad_buffer, atol=tollerance + ) + assert 
torch.allclose( + block_packed.mixer.out_proj.weight.grad_buffer, block_ref.mixer.out_proj.weight.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mixer.dt_in_proj.weight.grad_buffer, + block_ref.mixer.dt_in_proj.weight.grad_buffer, + atol=tollerance, + ) + + assert torch.allclose( + block_packed.mlp.layer_1.weight.grad_buffer, block_ref.mlp.layer_1.weight.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mlp.layer_1.bias.grad_buffer, block_ref.mlp.layer_1.bias.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mlp.layer_2.weight.grad_buffer, block_ref.mlp.layer_2.weight.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mlp.layer_2.bias.grad_buffer, block_ref.mlp.layer_2.bias.grad_buffer, atol=tollerance + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index aa8100126..a818a1f23 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -16,6 +16,7 @@ DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, LlamaCheckpointFormat, + LlavaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, MTPLlamaCheckpointFormat, @@ -680,6 +681,77 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) +_update_and_add_testing_config( + # Tests hybrid Mamba, llamba converter. + "llama", + "llava", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=8", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=256", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=8.0, +) + +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. 
+ "hybrid_mamba2", + "vision_hybrid_mamba2", + model_type="hybrid_ssm", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=8", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=512", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaHybridHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=16.0, +) + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: From 9114ce20fb9dd304e68d653fabca5b5800dd393a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 25 Sep 2025 23:30:09 -0400 Subject: [PATCH 02/15] Vision multimodal --- fast_llm/data/data/gpt/config.py | 40 +- fast_llm/data/data/gpt/data.py | 32 +- fast_llm/data/dataset/config.py | 19 +- .../dataset/gpt/components/__init__.py} | 0 .../data/dataset/gpt/components/config.py | 20 + .../data/dataset/gpt/components/images.py | 243 ++++++++ fast_llm/data/dataset/gpt/components/spans.py | 73 +++ .../data/dataset/gpt/components/tokens.py | 63 ++ fast_llm/data/dataset/gpt/config.py | 15 +- fast_llm/data/dataset/gpt/fim.py | 6 +- fast_llm/data/dataset/gpt/indexed.py | 46 +- fast_llm/data/dataset/gpt/memmap.py | 556 +++++++----------- fast_llm/data/dataset/gpt/sampled.py | 479 +++++++-------- fast_llm/data/tokenizer.py | 8 +- fast_llm/engine/config_utils/data_type.py | 2 +- fast_llm/layers/attention/attention.py | 4 +- fast_llm/layers/attention/rotary/rotary.py | 2 +- fast_llm/layers/block/block.py | 13 - fast_llm/layers/block/config.py | 2 - fast_llm/layers/common/linear/config.py | 50 +- fast_llm/layers/common/linear/convolution.py | 24 + fast_llm/layers/decoder/block.py | 36 +- fast_llm/layers/decoder/config.py | 23 +- .../layers/decoder/mlp/mixture_of_experts.py | 6 +- fast_llm/layers/decoder/mlp/mlp.py | 4 +- fast_llm/layers/language_model/config.py | 7 +- fast_llm/layers/language_model/embedding.py | 4 - fast_llm/layers/language_model/head.py | 4 - fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba.py | 9 +- fast_llm/layers/ssm/mamba2.py | 9 +- fast_llm/layers/vision_encoder/adapter.py | 55 -- fast_llm/layers/vision_encoder/config.py | 150 +++-- fast_llm/layers/vision_encoder/patch_conv.py | 62 -- .../vision_encoder/patch_convolution.py | 77 +++ 
.../layers/vision_encoder/preprocessing.py | 255 ++++---- fast_llm/models/custom/model.py | 70 --- fast_llm/models/gpt/config.py | 23 - fast_llm/models/gpt/model.py | 296 ++++------ fast_llm/models/gpt/trainer.py | 20 +- tests/data/common.py | 11 +- tests/data/test_sampling.py | 10 +- tools/concatenate_dataset.py | 1 + tools/generate_config_yaml_for_sharded_dst.py | 98 --- 44 files changed, 1470 insertions(+), 1461 deletions(-) rename fast_llm/{models/gpt/llava.py => data/dataset/gpt/components/__init__.py} (100%) create mode 100644 fast_llm/data/dataset/gpt/components/config.py create mode 100644 fast_llm/data/dataset/gpt/components/images.py create mode 100644 fast_llm/data/dataset/gpt/components/spans.py create mode 100644 fast_llm/data/dataset/gpt/components/tokens.py delete mode 100644 fast_llm/layers/vision_encoder/adapter.py delete mode 100644 fast_llm/layers/vision_encoder/patch_conv.py create mode 100644 fast_llm/layers/vision_encoder/patch_convolution.py delete mode 100644 fast_llm/models/custom/model.py delete mode 100644 tools/generate_config_yaml_for_sharded_dst.py diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index efee46959..405d1c672 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,16 +1,23 @@ import logging +import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig +from fast_llm.data.dataset.gpt.config import ( + GPTLegacyConfig, + GPTLegacyDatasetConfig, + GPTSampledDatasetConfig, + GPTSamplingConfig, +) +from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert logger = logging.getLogger(__name__) @config_class() -class GPTDataConfig(DataConfig): +class GPTDataConfig(DataConfig, GPTLegacyConfig): """ Configuration for the dataset(s), split and sampling. Currently hard-coded to a GPT dataset. @@ -41,3 +48,32 @@ class GPTDataConfig(DataConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) + + def _validate(self) -> None: + if not self.datasets: + logger.warning( + "Using the legacy dataset definition format." " Specify it through `data.datasets` instead." + ) + self.datasets = { + phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False) + for phase in (PhaseType.training, PhaseType.validation, PhaseType.test) + } + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. 
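As a rough illustration of the renaming performed by the compatibility code just below, here is a sketch with a made-up legacy config; the capitalized phase keys are an assumption inferred from the warning message:

# Hypothetical legacy config that keys datasets by capitalized phase names.
legacy_config = {"datasets": {"Training": {"type": "memmap", "path": "train_shard"}}}

# The shim rewrites each phase key to the lowercase form expected by the new format.
for phase_value in ("Training", "Validation", "Test"):
    if phase_value in legacy_config["datasets"]:
        assert phase_value.lower() not in legacy_config["datasets"]
        legacy_config["datasets"][phase_value.lower()] = legacy_config["datasets"].pop(phase_value)

assert legacy_config["datasets"] == {"training": {"type": "memmap", "path": "train_shard"}}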
+ if "datasets" in default: + for phase in PhaseType: + if phase.value in default["datasets"]: + rename = phase.value.lower() + logger.warning(f"Renaming dataset {phase.value} to {rename}") + assert rename not in default["datasets"] + default["datasets"][rename] = default["datasets"].pop(phase.value) + + return super()._from_dict(default, strict, flat) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 9df9b9b86..88841966c 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,8 +32,8 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None - images: list[torch.Tensor] | None = None - image_positions: list[torch.Tensor] | None = None + images: list[list[torch.Tensor]] | None = None + image_positions: list[list[torch.Tensor]] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None @@ -51,28 +51,24 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] - has_images = False - batch_images = [] - for sample in batch: - if sample.images is not None: - batch_images.append([torch.from_numpy(image) for image in sample.images]) - has_images = True - else: - batch_images.append([]) - batch_image_positions = [] - for sample in batch: - if sample.image_positions is not None: - batch_image_positions.append(torch.from_numpy(sample.image_positions)) - else: - batch_image_positions.append([]) + has_images = any(sample.images is not None for sample in batch) + if has_images: + images = [ + [] if sample.images is None else [torch.from_numpy(image) for image in sample.images] for sample in batch + ] + image_positions = [ + [] if sample.image_positions is None else torch.from_numpy(sample.image_positions) for sample in batch + ] + else: + images, image_positions = None, None return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, - images=batch_images if has_images else None, - image_positions=batch_image_positions if has_images else None, + images=images, + image_positions=image_positions, ) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 0c1b0cd09..5e3ced8a4 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -204,6 +204,11 @@ class BlendedDatasetConfig(SampledDatasetConfig): desc="The blending weight of each dataset.", hint=FieldHint.core, ) + legacy: bool = Field( + default=False, + desc="Use the legacy formulas for sub-dataset seeds and sample sizes.", + hint=FieldHint.deprecated, + ) def _validate(self) -> None: self.weights = normalize_probabilities(self.weights) @@ -226,10 +231,20 @@ def build_and_sample( sampling, parameters=dataclasses.replace( sampling.parameters, - num_samples=math.ceil(weight * sampling.parameters.num_samples) + 1, + num_samples=( + math.ceil( + weight + * ( + sampling.parameters.num_samples + + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5 + ) + ) + if self.legacy + else math.ceil(weight * sampling.parameters.num_samples) + 1 + ), ), # TODO: Seed may not be unique for nested blended datasets. 
- config=sampling.config.to_copy({"seed": sampling.config.seed + i * 697}), + config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}), ), ) for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) diff --git a/fast_llm/models/gpt/llava.py b/fast_llm/data/dataset/gpt/components/__init__.py similarity index 100% rename from fast_llm/models/gpt/llava.py rename to fast_llm/data/dataset/gpt/components/__init__.py diff --git a/fast_llm/data/dataset/gpt/components/config.py b/fast_llm/data/dataset/gpt/components/config.py new file mode 100644 index 000000000..59c419101 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/config.py @@ -0,0 +1,20 @@ +import dataclasses + +from fast_llm.engine.config_utils.data_type import DataType + +# TODO: Store span type? +# class SpanType(enum.StrEnum): +# none = "none" +# loss_masking = "loss_masking" +# preference = "preference" + + +@dataclasses.dataclass(kw_only=True) +class GPTMemmapDatasetHeader: + num_documents: int + token_data_type: DataType = DataType.int64 + has_spans: bool = False + has_images: bool = False + + def __post_init__(self): + self.token_data_type = DataType(self.token_data_type) diff --git a/fast_llm/data/dataset/gpt/components/images.py b/fast_llm/data/dataset/gpt/components/images.py new file mode 100644 index 000000000..cf217b5f0 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/images.py @@ -0,0 +1,243 @@ +import io +import math +import typing + +import numpy as np +import PIL.Image + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert, div + + +class GPTImageDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + + self._count_cumsum = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._header.num_documents + 1, + offset=offset.value, + ) + offset.value += self._count_cumsum.nbytes + self._sizes = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1] * 2, + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._sizes.nbytes + self._positions = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1], + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._positions.nbytes + + def get( + self, + index: int, + start_offset: int, + end_offset: int, + shift_map: ShiftMap, + buffer_offset: BufferOffset, + parameters: GPTSamplingParameters, + ) -> tuple[list[np.ndarray] | None, np.ndarray | None]: + # We get images from the document, discarding those outside the selected range. 
+ images = [] + positions = [] + for image_index in range(self._count_cumsum[index], self._count_cumsum[index + 1]): + image_buffer_size = self._sizes[image_index].prod(initial=3) + image_position = shift_map.shift(self._positions[image_index].item()) + if start_offset <= image_position < end_offset: + images.append( + np.frombuffer( + self._binary_buffer, + dtype=np.dtype(np.uint8), + count=image_buffer_size, + offset=buffer_offset.value, + ).reshape(3, *self._sizes[image_index]) + ) + positions.append(self._positions[image_index]) + + buffer_offset.value += image_buffer_size + + def _get_insert(self, image_index: int, parameters: GPTSamplingParameters): + height, width = resized_image_length + height_patches = div(height, parameters.patch_size) + width_patches = div(width, parameters.patch_size) + image_size = height_patches * width_patches + if parameters.image_break_token is not None: + image_size += height_patches + elif parameters.image_end_token is not None: + image_size += 1 + + image_token_array = np.full((image_size,), -100, dtype=np.int64) + if parameters.image_break_token is not None: + for row in range(height_patches): + position = (row + 1) * width_patches + row + image_token_array[position] = parameters.image_break_token + + if parameters.image_end_token is not None: + # Will override the last image_break_token. + image_token_array[-1] = parameters.image_end_token + + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, parameters.patch_size) + num_patches_w = div(width, parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if parameters.image_end_token is not None: + image_token_array[last_row_position] = parameters.image_end_token + else: + image_token_array[last_row_position] = parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if parameters.image_end_token is not None: + image_token_array[-1] = parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_sizes[idx] + start_pos = im_position + + resized_image_lengths = [ + get_resize_dims( + *image_length, + parameters.max_image_size, + parameters.max_image_size, + parameters.patch_size, + ) + for image_length in image_lengths + ] + return images, positions + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + has_images = document.images is not None + if "has_images" in index_data: + Assert.eq(index_data["has_images"], has_images) + else: + index_data["has_images"] = has_images + if has_images: + if "image_sizes" not in index_data: + 
index_data["image_sizes"] = [] + if "image_positions" not in index_data: + index_data["image_positions"] = [] + if "num_pixels" not in index_data: + index_data["num_pixels"] = 0 + for image, image_position in zip(document.images, document.image_positions, strict=True): + # assume 3 channels (RGB) for all images + # TODO: Not consistent with GPTSample? + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." + index_data["image_sizes"].append(np.array(pixels.shape[1:])) + index_data["image_positions"].append(image_position) + # TODO: Shouldn't pixel count exclude the channel dimension? + index_data["num_pixels"] += pixels.size + binary_stream.write(pixels.tobytes(order="C")) + # Cumsum holds both image counts and buffer offsets. + if "image_cumsum" not in index_data: + index_data["image_cumsum"] = [0] + index_data["image_cumsum"].append(len(index_data["image_sizes"])) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + if index_data["has_images"]: + Assert.leq(index_data["image_cumsum"][-1], np.iinfo(np.int32).max) + Assert.eq(len(index_data["image_cumsum"]), index_data["num_documents"] + 1) + Assert.eq(len(index_data["image_sizes"]), index_data["image_cumsum"][-1]) + Assert.eq(len(index_data["image_positions"]), index_data["image_cumsum"][-1]) + index_stream.write(np.array(index_data["image_cumsum"], dtype=np.int32).tobytes(order="C")) + # n_pixels * 3 per image + index_stream.write(np.stack(index_data["image_sizes"], dtype=np.int32).tobytes(order="C")) + # Position of each image in the document + index_stream.write(np.array(index_data["image_positions"], dtype=np.int32).tobytes(order="C")) + + def get_sizes(self, index: int, parameters: GPTSamplingParameters) -> list[int]: + return [ + get_num_image_tokens( + *get_resize_dims( + *size.item(), + parameters.max_image_size, + parameters.max_image_size, + parameters.patch_size, + ), + parameters.patch_size, + image_break=parameters.image_break_token is not None, + image_end=parameters.image_end_token is not None, + ) + for size in self._sizes[self._count_cumsum[index] : self._count_cumsum[index + 1]] + ] + + def get_unshifted_positions_and_sizes( + self, index: int, parameters: GPTSamplingParameters + ) -> list[tuple[int, int]]: + return [ + (position, size) + for position, size in zip( + self._positions[self._count_cumsum[index] : self._count_cumsum[index + 1]], + self.get_sizes(index, parameters), + strict=True, + ) + ] + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. 
+ If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) diff --git a/fast_llm/data/dataset/gpt/components/spans.py b/fast_llm/data/dataset/gpt/components/spans.py new file mode 100644 index 000000000..a4c331e00 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/spans.py @@ -0,0 +1,73 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert + + +class GPTSpansDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + + self._count_cumsum = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._header.num_documents + 1, + offset=offset.value, + ) + offset.value += self._count_cumsum.nbytes + self._spans = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1] * 2, + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._spans.nbytes + + def get(self, index: int, start_offset: int, end_offset: int, shift_map: ShiftMap) -> list[tuple[int, int]]: + loss_masking_spans = [] + for span_begin, span_end in self._spans[self._count_cumsum[index] : self._count_cumsum[index + 1]].tolist(): + span_begin = max(shift_map.shift(span_begin), start_offset) - start_offset + span_end = min(shift_map.shift(span_end), end_offset - 1) - start_offset + if span_end > span_begin: + loss_masking_spans.append((span_begin, span_end)) + return loss_masking_spans + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + has_spans = document.loss_masking_spans is not None + if "has_span" in index_data: + Assert.eq(index_data["has_span"], has_spans) + else: + index_data["has_span"] = has_spans + if has_spans: + if "spans" not in index_data: + index_data["spans"] = [] + index_data["spans"].extend(document.loss_masking_spans) + if "spans_cumsum" not in index_data: + index_data["spans_cumsum"] = [0] + index_data["spans_cumsum"].append(len(index_data["spans"])) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + if index_data["has_spans"]: + # Should be ok, checking just in case. 
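To make the span clipping in GPTSpansDatasetComponent.get above concrete, here is a small worked sketch with made-up offsets, assuming a document without images so the shift map is the identity (span boundaries are inclusive):

# A document-level loss-masking span [5, 20] read through the sampled window [10, 25).
span_begin, span_end = 5, 20
start_offset, end_offset = 10, 25

# Clip to the window and re-express the boundaries relative to the window start.
clipped_begin = max(span_begin, start_offset) - start_offset  # 0
clipped_end = min(span_end, end_offset - 1) - start_offset    # 10
assert clipped_end > clipped_begin                            # kept as (0, 10)

# A span that lies entirely outside the window collapses and is dropped.
assert not (min(4, end_offset - 1) - start_offset) > (max(2, start_offset) - start_offset)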
+ Assert.leq(index_data["spans_cumsum"][-1], np.iinfo(np.int32).max) + Assert.eq(len(index_data["spans_cumsum"]), index_data["num_documents"] + 1) + Assert.eq(len(index_data["spans"]), index_data["spans_cumsum"][-1]) + index_stream.write(np.array(index_data["spans_cumsum"], dtype=np.int32).tobytes(order="C")) + index_stream.write(np.vstack(index_data["spans"], dtype=np.int32).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/components/tokens.py b/fast_llm/data/dataset/gpt/components/tokens.py new file mode 100644 index 000000000..3a91fef54 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/tokens.py @@ -0,0 +1,63 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert + + +class GPTTokensDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + self.sizes = np.frombuffer( + self._index_binary_buffer, dtype=np.int32, count=self._header.num_documents, offset=offset.value + ) + self._item_size = self._header.token_data_type.numpy.itemsize + offset.value += self.sizes.nbytes + + def get( + self, index: int, start_offset: int, end_offset: int, shift_map: ShiftMap, buffer_offset: BufferOffset + ) -> np.ndarray: + unshifted_start_offset = shift_map.unshift(start_offset) + token_ids = np.frombuffer( + self._binary_buffer, + dtype=self._header.token_data_type, + count=shift_map.unshift(end_offset) - unshifted_start_offset, + offset=buffer_offset.value + unshifted_start_offset * self._item_size, + ) + buffer_offset.value += self.sizes[index] * self._item_size + return token_ids + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + if "token_data_type" in index_data: + Assert.eq(document.token_ids.dtype, index_data["token_data_type"]) + else: + index_data["token_data_type"] = document.token_ids.dtype + if "document_lengths" not in index_data: + index_data["document_lengths"] = [] + index_data["document_lengths"].append(document_length := len(document.token_ids)) + if "num_tokens" not in index_data: + index_data["num_tokens"] = 0 + index_data["num_tokens"] += document_length + + # Write document to binary file + binary_stream.write(document.token_ids.tobytes(order="C")) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + # Document (tokens) lengths. 
+ index_stream.write(np.array(index_data["document_lengths"], dtype=np.int32).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 8835480a9..782ed6b4e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -61,7 +61,15 @@ class GPTSamplingConfig(SamplingConfig): @dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): +class ImageSamplingParameters: + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None + + +@dataclasses.dataclass(kw_only=True) +class GPTSamplingParameters(SamplingParameters, ImageSamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ @@ -70,12 +78,9 @@ class GPTSamplingParameters(SamplingParameters): vocab_size: int use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False + use_images: bool = False cross_document_attention: bool = True truncate_documents: bool = True - patch_size: int | None = None - max_image_size: int | None = None - image_break_token: int | None = None - image_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index b05b79b24..843f6735d 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, add_eos=False)], dtype=np.int64) + middle = np.array([*self._tokenizer.tokenize(middle, add_bos=False, add_eos=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer.tokenize(suffix, add_bos=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index b069e36e2..d2bd84bf9 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -3,7 +3,7 @@ import numpy as np -from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset if typing.TYPE_CHECKING: @@ -12,7 +12,7 @@ class GPTIndexedDataset(IndexedDataset): @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, @@ -20,7 +20,7 @@ def get_document_sizes(self) -> np.ndarray: """ @abc.abstractmethod - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: """ The size of a document in the dataset. 
""" @@ -30,14 +30,6 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": return GPTSampledIndexedDataset(self, sampling) - @property - @abc.abstractmethod - def has_images(self) -> bool: - """ - Whether the dataset contains images. - This is used to determine whether to use image-related fields in the sampled data. - """ - class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ @@ -46,17 +38,12 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe _dataset: GPTIndexedDataset - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: # TODO: This can be really big. - doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([]) + return self._dataset.get_document_sizes(parameters)[self._begin : self._end] - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) - - @property - def has_images(self) -> bool: - return self._dataset.has_images + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: + return self._dataset.get_document_size(self._begin + index, parameters) class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( @@ -64,19 +51,10 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ): _datasets: list[GPTIndexedDataset] - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: # TODO: This can be really big. - # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - sizes = [dataset.get_document_sizes() for dataset in self._datasets] - return ( - np.concatenate([size[0] for size in sizes]), - np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]), - ) - - def get_document_size(self, index: int) -> int: - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + return np.concatenate([dataset.get_document_sizes(parameters) for dataset in self._datasets]) - @property - def has_images(self) -> bool: - return any(dataset.has_images for dataset in self._datasets) + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item(), parameters) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 4f62561a8..99d9957f8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,18 +1,52 @@ -import io +import functools +import json import pathlib import struct import typing import numpy as np -import PIL.Image +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.components.images import GPTImageDatasetComponent +from fast_llm.data.dataset.gpt.components.spans import GPTSpansDatasetComponent +from fast_llm.data.dataset.gpt.components.tokens import GPTTokensDatasetComponent +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.gpt.indexed 
import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_INDEX_HEADER from fast_llm.utils import Assert, div +class BufferOffset: + # This makes offsets mutable. + def __init__(self, value: int): + self.value: int = value + + +class ShiftMap: + """ + A map between original and shifted token indices (i.e., accounting for extra content such as images). + Also serves as a cache so we don't have to recompute positions and sizes every time. + """ + + def __init__(self, positions_and_sizes: list[tuple[int, int]]): + self._positions_and_sizes = positions_and_sizes + + @functools.cached_property + def shifted_positions(self) -> list[int]: + return [self.shift(position) for position, _ in self._positions_and_sizes] + + def shift(self, index: int) -> int: + return index + sum(size for position, size in self._positions_and_sizes if index > position) + + def unshift(self, index: int) -> int: + return index - sum( + size + for shifted_position, (_, size) in zip(self.shifted_positions, self._positions_and_sizes, strict=True) + if shifted_position < index + ) + + class GPTMemmapDataset(GPTIndexedDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, @@ -36,399 +70,233 @@ def _init( self, name: str, prefix: pathlib.Path | str, - num_documents: int | None, - num_tokens: int | None, - num_pixels: int | None, + num_documents: int | None = None, + num_tokens: int | None = None, + num_pixels: int | None = None, ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_images = 0 - self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 4: - self._has_images = struct.unpack("= 2 and bool(struct.unpack("= 3 and bool(struct.unpack("= 4 and bool(struct.unpack("= 2: - self._spans = [] - self._num_spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_documents, - offset=offset, - ) - self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] - for idx in range(self._num_documents): - self._spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=offset - + self._num_spans.nbytes - + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - offset += self._num_spans.nbytes + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = offset + np.array(self._chosen_spans).nbytes - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + 
idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes - - self._num_pixels = 0 - self._image_sizes = [] - self._image_positions = None - if self._has_images and self._version >= 4: - self._n_images = np.frombuffer( - self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset - ) - self._image_sizes = [] - self._image_positions = [] - images_seen = 0 - num_total_images = self._n_images.sum() - for n_images in self._n_images: - self._image_sizes.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=n_images * 2, - offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - self._num_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() - self._image_positions.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=n_images, - offset=offset - + self._n_images.nbytes - + 2 * num_total_images * np.dtype(np.int32).itemsize - + +images_seen * np.dtype(np.int32).itemsize, - ) - ) - images_seen += n_images - - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") - self._bin_buffer = memoryview(self._bin_buffer_mmap) - - self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: - assert self._num_pixels == num_pixels + Assert.eq(num_pixels, self._images.total_pixels) + + # TODO: Simplify. + self._num_tokens = ( + self._binary_buffer_mmap.size + if self._images is None + else self._binary_buffer_mmap.size - self._images.total_pixels + ) if num_tokens is not None: - assert self._num_tokens == num_tokens + Assert.eq(num_tokens, self._num_tokens) - def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._prefix) - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): if hasattr(self, "_bin_buffer_mmap"): - self._bin_buffer_mmap._mmap.close() # noqa - del self._bin_buffer_mmap + self._binary_buffer_mmap._mmap.close() # noqa + del self._binary_buffer_mmap if hasattr(self, "_index_bin_buffer"): - self._index_bin_buffer_mmap._mmap.close() # noqa - del self._index_bin_buffer_mmap + self._index_binary_buffer_mmap._mmap.close() # noqa + del self._index_binary_buffer_mmap def get( self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, + index: int, + start_offset: int = 0, + end_offset: int | None = None, + parameters: GPTSamplingParameters | None = None, ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - images = None - image_positions = None - if self._has_images: - image_positions = self._image_positions[idx] - - # Truncations with images are not yet supported, so we get all images from the document - pixels = np.frombuffer( - self._bin_buffer, - dtype=np.dtype(np.uint8), - count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), - offset=self._pointers[idx] + self._document_sizes[idx] * 
np.dtype(self._dtype).itemsize, - ) - images = [] - start = 0 - for image_size in self._image_sizes[idx]: - n_pixels = image_size.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) - start += n_pixels - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + if end_offset is None: + end_offset = self.get_document_size(index, parameters) - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + shift_map = ShiftMap( + self._images.get_unshifted_positions_and_sizes(index, parameters) if parameters.use_images else [] + ) - rejected_span = self._rejected_spans[idx] + buffer_offset = BufferOffset(self._buffer_offsets[index].item()) + sample = GPTSample(token_ids=self._tokens.get(index, start_offset, end_offset, shift_map, buffer_offset)) - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] + if parameters.use_loss_masking_spans: + sample.loss_masking_spans = self._spans.get(index, start_offset, end_offset, shift_map) - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + if parameters.use_images: + sample.images, sample.image_positions = self._images.get( + index, start_offset, end_offset, shift_map, buffer_offset + ) - return GPTSample( - token_ids=token_ids, - images=images, - image_positions=image_positions, - loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, - ) + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, parameters.patch_size) + num_patches_w = 
div(width, parameters.patch_size)
+                image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64)
+                # account for break tokens after each row
+                for row in range(num_patches_h - 1):
+                    position = (row + 1) * num_patches_w + row
+                    image_token_array[position] = parameters.image_break_token
+                # handle the last row separately
+                last_row_position = num_patches_h * num_patches_w + num_patches_h - 1
+                if parameters.image_end_token is not None:
+                    image_token_array[last_row_position] = parameters.image_end_token
+                else:
+                    image_token_array[last_row_position] = parameters.image_break_token
+            else:
+                image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64)
+                if parameters.image_end_token is not None:
+                    image_token_array[-1] = parameters.image_end_token
+            sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64))
+            text_tokens_added += len(text_part)
+            image_positions.append(text_tokens_added + image_tokens_added)
+            image_tokens_added += image_sizes[idx]
+            start_pos = im_position
+
+        return sample

     @property
     def name(self) -> str:
         return self._name

     def __len__(self) -> int:
-        return self._num_documents
-
-    @property
-    def num_tokens(self) -> int:
-        return self._num_tokens
+        return self._header.num_documents

-    @property
-    def has_images(self) -> bool:
-        return self._has_images
-
-    def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]:
+    def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray:
         """
         The size of each document in the dataset.
         The resulting array could be very large, so this method should be called cautiously,
         and derived classes should try to avoid holding the whole array in memory.
         """
-        return self._document_sizes, self._image_sizes
+        if parameters is not None and parameters.use_images:
+            # TODO: Optimize this.
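+            # Image token counts depend on the sampling parameters, so sizes are recomputed per document here.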
+ return np.array([self.get_document_size(index, parameters) for index in range(self._header.num_documents)]) + return self._tokens.sizes + + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: + size = self._tokens.sizes[index].item() + if parameters is not None and parameters.use_images: + for _, size_ in self._images.get_positions_and_sizes(index, parameters): + size += size_ + return size + + def _shift_offset(self, offset, index: int, parameters: GPTSamplingParameters | None = None) -> int: + if parameters is not None and parameters.use_images: + offset += sum( + size for position, size in self._images.get_positions_and_sizes(index, parameters) if position < offset + ) + return offset - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item(), self._image_sizes[index] if self._has_images else [] + def _unshift_offset(self, offset, index: int, parameters: GPTSamplingParameters | None = None) -> int: + unshifted_offset = offset + if parameters is not None and parameters.use_images: + for position, size in self._images.get_positions_and_sizes(index, parameters): + shifted_position = self._shift_offset(position, index, parameters) + if shifted_position < offset: + unshifted_offset -= size + return unshifted_offset @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): - # Initialize metadata - dtype = None + buffer_offsets = [] + index_data = {} num_documents = 0 - doc_lengths = [] - n_images = [] - image_sizes = [] - im_positions = [] - total_images = 0 - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] + component_classes = (GPTTokensDatasetComponent, GPTSpansDatasetComponent, GPTImageDatasetComponent) prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: + with prefix.with_suffix(".bin").open("wb") as binary_stream: + for document in documents: - # Infer dtype from the first document - if dtype is None: - dtype = document.token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." - - # Write document to binary file - bin_stream.write(document.token_ids.tobytes(order="C")) - total_im_size = 0 - if document.images: - n_images.append(len(document.images)) - total_images += len(document.images) - for image in document.images: - # assume 3 channels (RGB) for all images - with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - if img.mode != "RGB": - # Convert all images to RGB - img = img.convert("RGB") - pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW - assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
- image_sizes.append(np.array(pixels.shape[1:])) - bin_stream.write(pixels.tobytes(order="C")) - total_im_size += pixels.size - im_positions.extend(document.image_positions) - else: - n_images.append(0) - - # Update metadata - doc_length = len(document.token_ids) - doc_lengths.append(doc_length) - pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize - num_documents += 1 + buffer_offsets.append(binary_stream.tell()) + for component_class in component_classes: + component_class.write_document_and_gather_index(document, index_data, binary_stream) + + # TODO: Address + assert document.chosen_span is None and document.rejected_span is None - # Finalize metadata arrays - doc_lengths = np.array(doc_lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) - - if total_images: - n_images = np.array(n_images, dtype=np.int32) - image_sizes = np.stack(image_sizes, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) - else: - n_images = np.array([]) - image_sizes = np.array([]) - im_positions = np.array([]) + num_documents += 1 # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 onwards optionally add loss-masking spans - # Version 3 optionally adds chosen/rejected spans - # Version 4 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Flag to indicate whether images are present - idx_stream.write(struct.pack(" 0 else 0)) - # Data type - idx_stream.write(struct.pack(" None: """ Create a `GPTSampledDataset` with the requested parameters. """ - # Get the document sizes, the main information needed for sampling. - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() - document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes: - image_token_sizes = [] - for i, sizes in enumerate(image_sizes): - image_token_sizes.append( - sum( - get_num_image_tokens( - *get_resize_dims( - *size, - self._parameters.max_image_size, - self._parameters.max_image_size, - self._parameters.patch_size, - ), - self._parameters.patch_size, - image_break=self._parameters.image_break_token is not None, - image_end=self._parameters.image_end_token is not None, - ) - for size in sizes - ) - ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + # Get the size each document, the main information needed for sampling. + # Note: "document" may refer to more than just text. 
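+        # With `use_images`, each reported size already includes the image placeholder tokens.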
+ document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes(self._parameters)).to(self._device) + + documents_per_epoch, tokens_per_epoch, long_docs_filter = self._get_epoch_size(document_sizes) + num_epochs, shuffled_epochs = self._get_epoch_count(documents_per_epoch, tokens_per_epoch) + + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data, cached = self._get_and_compare_yaml_data(documents_per_epoch, tokens_per_epoch, unshuffled_epochs) + if cached: + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." + ) + + document_shuffling = self._get_document_shuffling(documents_per_epoch, shuffled_documents, shuffled_epochs) + + # To get a sample on the fly we need to know where it begins, + # and this is a non-trivial information because the documents have variable length. + # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. + # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. + # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. + # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + + # TODO: Allowing for max 100% extra tokens for padding, is that enough? + cumsum_dtype = get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs) + if unshuffled_epochs > 0: + token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(document_sizes, 0, cumsum_dtype) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) else: - image_token_sizes = torch.zeros_like(document_sizes) + unshuffled_tokens = 0 - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() + if shuffled_epochs > 0: + token_cumsum_shuffled, _ = self._get_token_cumsum( + document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + self._unshuffled_tokens, + cumsum_dtype, + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(force=True) + ) - # Calculate basic stats. 
-        if not self._truncate_documents:
+            yaml_data["unshuffled_tokens"] = unshuffled_tokens
+        self._load_yaml_data(yaml_data)
+        if self._yaml_path is not None:
+            self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
+            yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
+
+    def _get_epoch_size(self, document_sizes: torch.Tensor) -> tuple[int, int, torch.Tensor | None]:
+        documents_per_epoch = document_sizes.numel()
+        if self._truncate_documents:
+            tokens_per_epoch = document_sizes.sum().item()
+            long_docs_filter = None
+        else:
             assert _extension_available, (
                 "The C++ extension for dataset sampling is missing."
                 " Please make sure Fast-LLM is installed correctly."
             )
-            long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1
-            ignored_documents = long_docs_filter.sum().item()
-            if ignored_documents:
+            long_docs_filter = document_sizes <= self._parameters.sequence_length + 1
+            documents_per_epoch_filtered = long_docs_filter.sum().item()
+            if ignored_documents := documents_per_epoch - documents_per_epoch_filtered:
                 log_main_rank(
-                    f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
+                    f" > {ignored_documents}/{documents_per_epoch} documents"
+                    f" are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
                     log_fn=logger.warning,
                 )
-            tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item()
+            tokens_per_epoch = document_sizes[long_docs_filter].sum().item()
             if tokens_per_epoch == 0:
                 raise RuntimeError(
-                    f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
+                    f" > No documents shorter than {self._parameters.sequence_length+1}"
+                    f" tokens found in dataset {self._indexed_dataset.name}."
                 )
+        return documents_per_epoch, tokens_per_epoch, long_docs_filter

+    def _get_epoch_count(self, documents_per_epoch: int, tokens_per_epoch: int) -> tuple[int, int]:
         # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads,
         # but in case of truncations we also include those last labels in the following sample,
         # so we need `sequence_length * num_samples + extra_tokens` tokens in total.
-        if self._parameters.use_preference_loss_spans:
-            documents_per_epoch = (~long_docs_filter).sum().item()
-            num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch)
-        elif self._truncate_documents:
+        if self._truncate_documents:
             num_epochs = math.ceil(
                 (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens)
                 / tokens_per_epoch
@@ -206,35 +246,34 @@ def _sample(self) -> None:
             )

         # Prepare for shuffling.
- generator = torch.Generator(device=self._device) if self._config.shuffle == ShufflingType.skip_first_epoch: shuffled_epochs = num_epochs - 1 elif self._config.shuffle == ShufflingType.disabled: shuffled_epochs = 0 else: shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs + return num_epochs, shuffled_epochs + def _get_and_compare_yaml_data( + self, + documents_per_epoch: int, + tokens_per_epoch: int, + unshuffled_epochs: int, + ) -> tuple[dict[str, typing.Any], bool]: yaml_data = { "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, "tokens_per_epoch": tokens_per_epoch, }, - "num_samples": self._parameters.num_samples, + "sampling": self._parameters.__dict__, "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._parameters.sequence_length, - "patch_size": self._parameters.patch_size, - "truncate_documents": self._truncate_documents, - "image_break_token": self._parameters.image_break_token, - "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs - if self._yaml_path is not None and self._yaml_path.is_file(): + if cached := (self._yaml_path is not None and self._yaml_path.is_file()): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: @@ -251,123 +290,8 @@ def _sample(self) -> None: ) # Dataset is already sampled, skip. logger.info(f"Using existing sampling for dataset {self.name}") - return - - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. 
- if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - - # To get a sample on the fly we need to know where it begins, - # and this is a non-trivial information because the documents have variable length. - # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. - # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. - # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if unshuffled_epochs > 0: - token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes, - offset=0, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
- dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) - else: - unshuffled_tokens = 0 - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens - self._load_yaml_data(yaml_data) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - if shuffled_epochs > 0: - token_cumsum_shuffled, _ = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 - ) - ] - + image_token_sizes[ - document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) - ], - offset=self._unshuffled_tokens, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled) - self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( - force=self._config.gpu - ) - ) - # Free memory - del document_shuffling + return yaml_data, cached def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._truncate_documents: @@ -410,6 +334,50 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - ] return out, num_tokens + def _get_document_shuffling( + self, + documents_per_epoch: int, + shuffled_documents: int, + shuffled_epochs: int, + ) -> torch.Tensor | None: + generator = torch.Generator(device=self._device) + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. 
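+        # `full` shuffles documents across epoch boundaries; `epoch` and `skip_first_epoch` shuffle each epoch independently.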
+ if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + return document_shuffling + def __len__(self) -> int: return self._parameters.num_samples @@ -421,39 +389,6 @@ def __getitem__(self, index: int) -> typing.Any: """ self._lazy_load() - if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) - - return sample - # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -479,10 +414,12 @@ def __getitem__(self, index: int) -> typing.Any: token_count = token_start_array[token_start_cumsum_index] token_ids = [] - loss_masking_spans = [] - images = [] - image_positions = [] - image_tokens_added = 0 + if self._parameters.use_loss_masking_spans: + loss_masking_spans = [] + if self._parameters.use_images: + images = [] + image_positions = [] + image_tokens_added = 0 text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
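For reference, the image-placeholder layout used in the hunk below (and in the memmap `get` above) amounts to the following sketch; the names are illustrative, and it assumes `get_num_image_tokens` reserves one extra slot per patch row when `image_break_token` is set, or a single extra slot for `image_end_token`:

import numpy as np

def build_image_placeholder(
    num_patches_h: int, num_patches_w: int, image_break_token: int | None, image_end_token: int | None
) -> np.ndarray:
    # One ignored (-100) slot per patch, plus room for the optional break/end tokens.
    extra = num_patches_h if image_break_token is not None else int(image_end_token is not None)
    tokens = np.full(num_patches_h * num_patches_w + extra, -100, dtype=np.int64)
    if image_break_token is not None:
        # A break token closes every row; the last row is overwritten below when an end token is configured.
        for row in range(num_patches_h):
            tokens[(row + 1) * num_patches_w + row] = image_break_token
    if image_end_token is not None:
        tokens[-1] = image_end_token
    return tokens

For example, a 2x3 patch grid with only a break token yields 8 slots: [-100, -100, -100, BREAK, -100, -100, -100, BREAK].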
@@ -491,28 +428,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) - - resized_image_lengths = [ - get_resize_dims( - *image_length, - self._parameters.max_image_size, - self._parameters.max_image_size, - self._parameters.patch_size, - ) - for image_length in image_lengths - ] - image_sizes = [ - get_num_image_tokens( - *image_length, - self._parameters.patch_size, - image_break=self._parameters.image_break_token is not None, - image_end=self._parameters.image_end_token is not None, - ) - for image_length in resized_image_lengths - ] - image_tokens = sum(image_sizes) - document_size = text_size + image_tokens + document_size = self._indexed_dataset.get_document_size(document_index, self._parameters) if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -542,16 +458,15 @@ def __getitem__(self, index: int) -> typing.Any: if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, text_size) - sample = self._indexed_dataset.get( + token_end_index_in_document = min(token_end - token_count, document_size) + sample: GPTSample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - start_pos = 0 - has_images = sample.image_positions is not None - if has_images: + if self._parameters.use_images: + start_pos = 0 sample_token_ids = [] for idx, im_position in enumerate(sample.image_positions): # add placeholder masked tokens for images @@ -586,42 +501,25 @@ def __getitem__(self, index: int) -> typing.Any: sample_token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(sample_token_ids[-1]) token_ids.append(np.concatenate(sample_token_ids)) - else: - token_ids.append(sample.token_ids[start_pos:]) - text_tokens_added += len(token_ids[-1]) - if sample.images: images.append(sample.images) else: - images.append([]) + token_ids.append(sample.token_ids) + text_tokens_added += len(token_ids[-1]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: - prev_image_tokens = 0 - image_idx = 0 - image_position = ( - sample.image_positions[image_idx] - if has_images and image_idx < len(sample.image_positions) - else float("inf") - ) - while image_position < loss_masking_span[0]: - prev_image_tokens += image_sizes[image_idx] - image_idx += 1 - image_position = ( - sample.image_positions[image_idx] - if has_images and image_idx < len(sample.image_positions) - else float("inf") + if self._parameters.use_images: + # Shift the spans to account for the images. 
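+                        # Spans are stored in text-token coordinates; each boundary moves right by the image tokens inserted before it.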
+ loss_masking_span[0] += sum( + image_size + for image_size, image_position in zip(image_sizes, sample.image_positions) + if image_position < loss_masking_span[0] ) - span_image_tokens = 0 - while image_position <= loss_masking_span[1]: - span_image_tokens += image_sizes[image_idx] - image_idx += 1 - image_position = ( - sample.image_positions[image_idx] - if has_images and image_idx < len(sample.image_positions) - else float("inf") + loss_masking_span[1] += sum( + image_size + for image_size, image_position in zip(image_sizes, sample.image_positions) + if image_position < loss_masking_span[1] ) - loss_masking_span[0] += prev_image_tokens - loss_masking_span[1] += prev_image_tokens + span_image_tokens - prev_image_tokens += span_image_tokens + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -639,14 +537,15 @@ def __getitem__(self, index: int) -> typing.Any: if not self._parameters.cross_document_attention else None ) + token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) if self._parameters.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] if images else None - image_positions = np.array(image_positions) if image_positions else None + images = [im for img_list in images for im in img_list] if self._parameters.use_images else None + image_positions = np.array(image_positions) if self._parameters.use_images else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample( @@ -661,20 +560,36 @@ def __getitem__(self, index: int) -> typing.Any: def name(self) -> str: return self._indexed_dataset.name + def _get_image_sizes(self, document_index: int): + # TODO: Duplicate of _get_document_sizes + image_lengths = self._indexed_dataset.get_image_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + return resized_image_lengths, image_sizes, image_tokens + def _lazy_load(self): if not hasattr(self, "_documents_per_epoch"): self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - - if self._parameters.use_preference_loss_spans: - data["unshuffled_tokens"] = 0 # not used, ignore - elif "unshuffled_tokens" not in data: - # Backward compatibility - # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index d46e38935..23a839af7 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -49,7 +49,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: ) def tokenize( - self, text: str, char_spans=None, image_positions=None + self, text: str, add_bos=True, 
add_eos=True, char_spans=None, image_positions=None ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. @@ -81,7 +81,9 @@ def tokenize( # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times if char_pos < position[0]: tokenized_text = self._tokenize( - text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 + text[char_pos : position[0]], + begin=add_bos and char_pos == 0, + end=add_eos and position[0] > len(text) - 1, ) token_ids.extend(tokenized_text) char_pos = position[0] @@ -106,7 +108,7 @@ def tokenize( current_span_start = None # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0), end=True) + tokenized_text = self._tokenize(text[char_pos:], begin=add_bos and char_pos == 0, end=add_eos) token_ids.extend(tokenized_text) return token_ids, token_spans, image_token_positions diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 0929b7cb1..f4a2cfd6c 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -9,7 +9,7 @@ from triton import language as tl -class DataType(str, enum.Enum): +class DataType(enum.StrEnum): """ An enum to represent data types independently of third party libraries, so we can swap them more easily and allow for lazy imports. diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9a2be6b49..2d4f049f0 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -64,6 +64,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -71,6 +72,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) @@ -273,7 +275,7 @@ def _query_key_value_backward( input_grad.add_(self.key_value.backward(key_value_grad, context.pop("key_value"))) return input_grad - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 58a786947..6a85f43e5 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -203,7 +203,7 @@ def _get_correction(self, beta: float, dim: int) -> float: ) -class Rotary2D[ConfigType: DefaultRotaryConfig](DefaultRotary[Rotary2DConfig]): +class Rotary2D[ConfigType: Rotary2DConfig](DefaultRotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_num_patches: int = -1 diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 773cce87e..0f975c9c5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -123,16 +123,3 @@ class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): """ Base class for actual blocks, i.e., base blocks that are also `Layers`. 
""" - - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - return_input: bool = False, - ): - super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) - self._return_input = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index df5bd8181..47c1ab9b7 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -87,7 +87,6 @@ def get_block( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ) -> "Block": return self.layer_class( self, @@ -95,7 +94,6 @@ def get_block( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, - return_input=return_input, ) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e92..0dc118269 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,7 +1,12 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import ( + Initialization, + init_normal_, + init_uniform_centered_, + init_zeros_, +) from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.config import ActivationType @@ -9,7 +14,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear.convolution import CausalConv1d + from fast_llm.layers.common.linear.convolution import CausalConv1d, Convolution2D from fast_llm.layers.common.linear.linear import LinearBase @@ -217,3 +222,44 @@ def get_layer( return CausalConv1d( weight, bias, activation=default_activation if self.activation is None else self.activation ) + + +@config_class +class Convolution2DConfig(AffineLinearBaseConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + kernel_dim_1: TensorDim, + kernel_dim_2: TensorDim, + *, + stride: tuple[int, int], + default_weight_initialization: Initialization | None = None, + default_bias_initialization: Initialization | None = None, + default_add_bias: bool = True, + lr_scale: float | None, + peft: PeftConfig | None, + ) -> "Convolution2D": + from fast_llm.layers.common.linear.convolution import Convolution2D + + if default_weight_initialization is None: + default_weight_initialization = init_normal_() + if default_bias_initialization is None: + default_bias_initialization = init_normal_() + + lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),) + weight = self.weight.get_parameter( + (out_dim, in_dim, kernel_dim_1, kernel_dim_2), + default_initialization=default_weight_initialization, + lr_scale=lr_scale, + peft=peft, + ) + bias = self.bias.get_parameter( + (out_dim,), + default_initialization=default_bias_initialization, + lr_scale=lr_scale, + default_enabled=default_add_bias, + peft=peft, + ) + + return Convolution2D(weight, bias, stride=stride) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index b88b7b2e6..6281348e1 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ 
-55,3 +55,27 @@ def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: raise NotImplementedError() + + +class Convolution2D(torch.nn.Module): + """ + TODO: Generalize to other convolutions? + """ + + def __init__( + self, + weight: ParameterMeta, + bias: ParameterMeta | None, + *, + stride: tuple[int, int], + ): + super().__init__() + self.weight = weight + self.bias = bias + self._stride = stride + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._stride) + + def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: + raise NotImplementedError() diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index ba4c370c2..ce9893ea4 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.block import BaseBlock, Block +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -19,18 +19,44 @@ logger = logging.getLogger(__name__) -class BlockWithBias[ConfigType: Config](BaseBlock[ConfigType]): +class BlockWithBias[ConfigType: Config](Block[ConfigType]): """ Base class for mixer and MLP modules. """ - @abc.abstractmethod + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self._return_bias = return_bias + def forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor: + output, bias = self._forward(input_, kwargs, losses, metrics) + if self._return_bias: + return output, bias + else: + return output if bias is None else output + bias + + @abc.abstractmethod + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: pass @@ -58,7 +84,7 @@ def __init__( peft=peft, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
- self._return_input: bool = return_input + self._return_input = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) @@ -70,6 +96,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) self.mlp = self._config.mlp.get_layer( @@ -77,6 +104,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) def setup(self, distributed: Distributed) -> None: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 5f8131b5c..724b8d172 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BaseBlockConfig, BlockConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -15,7 +15,7 @@ @config_class() -class BlockWithBiasConfig(BaseBlockConfig): +class BlockWithBiasConfig(BlockConfig): """ A common interface for various blocks and block layers. """ @@ -30,6 +30,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = False, ) -> "BlockWithBias": return self.layer_class( self, @@ -37,6 +38,7 @@ def get_layer( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + return_bias=return_bias, ) @@ -94,6 +96,23 @@ def layer_class(self) -> "type[DecoderBlock]": return DecoderBlock + def get_block( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_input: bool = False, + ) -> "DecoderBlock": + return self.layer_class( + self, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + return_input=return_input, + ) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 089fa2dc7..d4cb46dbf 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -46,6 +46,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): Assert.gt(config.experts, 1) # TODO: Implement? 
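The `return_bias` argument threaded through these mixer/MLP constructors follows the deferred-bias pattern of `BlockWithBias.forward` above; a minimal standalone sketch of the idea (illustrative names, not the Fast-LLM classes):

import torch

class BiasDeferringLinear(torch.nn.Module):
    # Minimal sketch: compute the bias separately so the caller can fuse the add with dropout or the residual.
    def __init__(self, hidden_size: int, return_bias: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
        self._return_bias = return_bias

    def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        output = self.linear(input_)
        # Mirrors the wrapper: return `(output, bias)` for fused callers, or the summed tensor for standalone use.
        return (output, self.bias) if self._return_bias else output + self.bias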
@@ -56,6 +57,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self.router = self._config.router.get_layer( self._hidden_dim, @@ -83,9 +85,9 @@ def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 9dd17d698..aaea94adb 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -28,6 +28,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -35,6 +36,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() @@ -102,7 +104,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _config: MLPConfig - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 5bd35eb71..216db6de3 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -401,7 +401,12 @@ def get_blocks(self, distributed_config: DistributedConfig): peft=self.peft, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1, + # TODO: Not all blocks support this argument. 
+ **( + {"return_input": True} + if self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1 + else {} + ), ) for i in range(len(self.decoder)) ], diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d1e13a5b..362ffaa22 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -36,17 +36,13 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index eb8302025..21753006a 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -58,17 +58,13 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f014012b2..c9fc609b0 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -43,6 +43,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -50,6 +51,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) @@ -128,7 +130,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e77a4468b..081aabe65 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -43,13 +43,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" @@ -120,7 +117,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b0657313d..4b0bd4366 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -41,13 +41,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, 
hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) num_heads = div(self._config.d_inner, self._config.state_size) @@ -153,7 +150,7 @@ def __init__( BlockDimNames.sequence_q, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py deleted file mode 100644 index 7ec50dfee..000000000 --- a/fast_llm/layers/vision_encoder/adapter.py +++ /dev/null @@ -1,55 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.functional.triton.mlp import torch_mlp_activation -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames -from fast_llm.tensor import TensorMeta, init_normal_ - - -class VisionAdapter(Layer): - """ - Vision adapter layer for the LLM. - """ - - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): - super().__init__() - input_dim = tensor_space[VisionEncoderDimNames.out_channels] - self._activation_type = config.adapter_activation_type - self.layer_1 = Linear( - input_dim, - tensor_space[VisionEncoderDimNames.adapter_size], - bias=True, - weight_init_method=init_normal_(std=config.adapter_init_method_std), - bias_init_method=init_normal_(std=config.adapter_init_method_std), - lr_scale=config.adapter_lr_scale, - ) - self.layer_2 = Linear( - tensor_space[VisionEncoderDimNames.adapter_size], - tensor_space[TransformerDimNames.hidden], - bias=True, - weight_init_method=init_normal_(std=config.adapter_init_method_std), - bias_init_method=init_normal_(std=config.adapter_init_method_std), - lr_scale=config.adapter_lr_scale, - ) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Vision adapter output", - dtype=input_.dtype, - ) - return self.layer_2( - torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) - ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index a705d948a..966c7843b 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,12 +1,16 @@ import enum +import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.utils import Assert +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig +from fast_llm.layers.common.linear.config import Convolution2DConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.decoder.config 
import MLPBaseConfig + +if typing.TYPE_CHECKING: + pass class VisionEncoderDimNames: @@ -24,9 +28,6 @@ class VisionEncoderKwargs: image_positions = "image_positions" max_image_size = "max_image_size" image_sizes = "image_sizes" - image_mean = "image_normalization_mean" - image_std = "image_normalization_std" - image_rescale_factor = "image_rescale_factor" rope_theta = "vit_rope_theta" rotary_inv_freq = "vit_rotary_inv_freq" kv_channels = "vit_kv_channels" @@ -82,48 +83,83 @@ class VisionEncoderType(str, enum.Enum): pixtral = "pixtral" -@config_class(registry=True) -class VisionEncoderConfig(BaseModelConfig): +@config_class() +class PatchConvolutionConfig(BlockConfig): _abstract = False - - type: VisionEncoderType = Field( - default=VisionEncoderType.none, - desc="Type of the vision encoder. Choices: none, pixtral.", + convolution: Convolution2DConfig = Field( + desc="Configuration for the 2d convolution.", hint=FieldHint.architecture, ) - transformer: TransformerConfig = Field( - desc="Configuration for the vision transformer architecture.", - hint=FieldHint.core, + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layer.", + hint=FieldHint.architecture, ) patch_size: int = Field( default=16, - desc="Patch size for the image encoder.", + desc="Size of image patches, in pixels (width and height).", hint=FieldHint.core, ) - conv_bias: bool = Field( - default=False, - desc="Whether to use bias in the convolutional layer.", - hint=FieldHint.optional, + input_channels: int = Field( + default=3, + desc="Number of pixel channels (usually 3).", + hint=FieldHint.feature, ) - patch_norm: NormalizationConfig = Field( - desc="Configuration for the normalization layers applied to the image patches.", - hint=FieldHint.optional, + + +@config_class(registry=True) +class VisionEncoderConfig(BaseModelConfig): + _abstract = False + patch_convolution_layer: PatchConvolutionConfig = Field( + desc="Configuration for the patch convolution layer.", + hint=FieldHint.architecture, ) - adapter_size: int = Field( - default=5120, - desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", - hint=FieldHint.core, + adapter_layer: MLPBaseConfig = Field( + desc="Configuration for the adapter layer.", + hint=FieldHint.architecture, ) - adapter_activation_type: ActivationType = Field( - default=ActivationType.gelu, - desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", - hint=FieldHint.core, + decoder: BlockSequenceConfig = Field( + desc="Configuration for the vision decoder.", + hint=FieldHint.architecture, ) - adapter_bias: bool = Field( - default=True, - desc="Whether to use bias in the adapter linear layer.", - hint=FieldHint.optional, + + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. 
Choices: none, pixtral.", + hint=FieldHint.architecture, ) + # transformer: TransformerConfig = Field( + # desc="Configuration for the vision transformer architecture.", + # hint=FieldHint.core, + # ) + # patch_size: int = Field( + # default=16, + # desc="Patch size for the image encoder.", + # hint=FieldHint.core, + # ) + # conv_bias: bool = Field( + # default=False, + # desc="Whether to use bias in the convolutional layer.", + # hint=FieldHint.optional, + # ) + # patch_norm: NormalizationConfig = Field( + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional, + # ) + # adapter_size: int = Field( + # default=5120, + # desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + # hint=FieldHint.core, + # ) + # adapter_activation_type: ActivationType = Field( + # default=ActivationType.gelu, + # desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + # hint=FieldHint.core, + # ) + # adapter_bias: bool = Field( + # default=True, + # desc="Whether to use bias in the adapter linear layer.", + # hint=FieldHint.optional, + # ) image_normalization: ImageNormalizationConfig = Field( desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, @@ -138,24 +174,24 @@ class VisionEncoderConfig(BaseModelConfig): desc="Token id to indicate the end of an image. If None, no token id is applied.", hint=FieldHint.optional, ) - adapter_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the adapter weights.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - conv_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the convolutional layer weights.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - adapter_init_method_std: float = Field( - default=None, - desc="Standard deviation for the normal initialization of the adapter weights. Default: adapter_size ** -0.5.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) + # adapter_lr_scale: float | None = Field( + # default=None, + # desc="Custom learning rate scale for the adapter weights.", + # hint=FieldHint.feature, + # valid=skip_valid_if_none(check_field(Assert.geq, 0)), + # ) + # conv_lr_scale: float | None = Field( + # default=None, + # desc="Custom learning rate scale for the convolutional layer weights.", + # hint=FieldHint.feature, + # valid=skip_valid_if_none(check_field(Assert.geq, 0)), + # ) + # adapter_init_method_std: float = Field( + # default=None, + # desc="Standard deviation for the normal initialization of the adapter weights. 
Default: adapter_size ** -0.5.", + # hint=FieldHint.optional, + # valid=check_field(Assert.geq, 0), + # ) def _validate(self) -> None: with self._set_implicit_default(): @@ -163,7 +199,7 @@ def _validate(self) -> None: self.adapter_init_method_std = self.adapter_size**-0.5 super()._validate() - def setup_tensor_space(self, tensor_space: TensorSpace): + def setup_tensor_space(self, tensor_space: "TensorSpace"): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py deleted file mode 100644 index 6c2a70930..000000000 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ /dev/null @@ -1,62 +0,0 @@ -import typing - -import torch - -from fast_llm.core.ops import split -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ - - -class PatchConv(Layer): - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): - super().__init__() - self._tensor_space = tensor_space - self._distributed_config = tensor_space.distributed_config - self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._lr_scale = config.adapter_lr_scale - self.weight = ParameterMeta.from_dims( - ( - self._tensor_space[VisionEncoderDimNames.out_channels], - self._tensor_space[VisionEncoderDimNames.in_channels], - self._tensor_space[VisionEncoderDimNames.patch_size], - self._tensor_space[VisionEncoderDimNames.patch_size], - ), - init_method=init_normal_(), - lr_scale=self._lr_scale, - ) - if config.conv_bias: - self.bias = ParameterMeta.from_dims( - (self._tensor_space[VisionEncoderDimNames.out_channels],), - init_method=init_normal_(), - lr_scale=self._lr_scale, - ) - else: - self.bias = None - self.norm = config.patch_norm.get_layer(tensor_space[VisionEncoderDimNames.out_channels]) - self.stride = config.patch_size - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict | None = None, - ) -> torch.Tensor: - hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] - if isinstance(input_, TensorMeta): - return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] - sequence_length = kwargs[TransformerKwargs.sequence_length] - out_channels = kwargs[VisionEncoderKwargs.out_channels] - reshape_dims = (micro_batch_size, sequence_length, out_channels) - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.view(reshape_dims) - if self._sequence_parallel: - patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() - patch_embeddings = split(patch_embeddings, group=group, dim=0) - return patch_embeddings diff --git 
a/fast_llm/layers/vision_encoder/patch_convolution.py b/fast_llm/layers/vision_encoder/patch_convolution.py new file mode 100644 index 000000000..0a41efe2c --- /dev/null +++ b/fast_llm/layers/vision_encoder/patch_convolution.py @@ -0,0 +1,77 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.block.block import Block +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision_encoder.config import PatchConvolutionConfig, VisionEncoderKwargs +from fast_llm.tensor import TensorMeta + + +class PatchConvolution[ConfigType: PatchConvolutionConfig](Block[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + # TODO: Input or output dim? + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + input_dim = TensorDim("input_channels", self._config.input_channels) + patch_dim = TensorDim("patch", self._config.patch_size) + + self.convolution = self._config.convolution.get_layer( + self._hidden_dim, + input_dim, + patch_dim, + patch_dim, + stride=(self._config.patch_size, self._config.patch_size), + default_add_bias=False, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.normalization = self._config.normalization.get_layer(hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + # TODO: Replace last dim instead + hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] + sequence_length = kwargs[AttentionKwargs.sequence_length] + out_channels = kwargs[VisionEncoderKwargs.out_channels] + # TODO: Avoid padding + reshape_dims = (micro_batch_size, sequence_length, out_channels) + group = self._tensor_space.distributed.tensor_group + + input_ = self.convolution(input_) + patch_embeddings = self.norm(input_.flatten(1)) + patch_embeddings = patch_embeddings.view(reshape_dims) + + # TODO: Sequence first" + if sequence_first: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + if self._sequence_parallel: + patch_embeddings = split(patch_embeddings, group=group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index adacd380c..65eac6c45 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -5,15 +5,16 @@ import torchvision.transforms.v2 as torchvision_transforms from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs -from 
fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.vision_encoder.config import ImageNormalizationConfig, VisionEncoderConfig, VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div -def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: +def get_num_patches(height: int, width: int, patch_size: int) -> int: """ Calculate the number of patches in height and width dimensions. """ @@ -49,21 +50,14 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int, pa return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) -def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: - target_height, target_width = get_resize_dims( - image.size(1), image.size(2), max_height, max_width, patch_size=patch_size - ) - height, width = image.size(1), image.size(2) - while height > 2 * target_height or width > 2 * target_width: - # cap the resizing to half of the current size as a workaround for large images - # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 - intermediate_max_width = max(target_width, width // 2) - intermediate_max_height = max(target_height, height // 2) - height, width = get_resize_dims( - height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size - ) +def resize(image: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + while max(image.size(1) / target_height, image.size(2) / target_width) > 2: image = torchvision_transforms.functional.resize( - image, size=(height, width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + image, + size=(math.ceil(image.size(1) / 2), math.ceil(image.size(2) / 2)), + interpolation=torchvision_transforms.InterpolationMode.BICUBIC, ) # TODO: options for interpolation mode? @@ -72,22 +66,6 @@ def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int ) -def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: - """ - Normalize the image using the specified mean and standard deviation. 
- """ - return torchvision_transforms.functional.normalize(image, mean=mean, std=std) - - -def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: - """ - Pad images on the right and bottom with 0s untitl max_height and max_width - """ - width_padding = max(0, max_height - image.size(1)) - depth_padding = max(0, max_width - image.size(2)) - return torchvision_transforms.functional.pad(image, (0, 0, depth_padding, width_padding), 0) - - def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) max_patches_per_side = max_image_size // patch_size @@ -111,57 +89,36 @@ def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, pat def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: patch_height = height // patch_size patch_width = width // patch_size - mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_size + v_grid - return ids[:, 0] + return torch.arange(patch_height).repeat_interleave(patch_width) * max_size + torch.arange(patch_width).repeat( + patch_height + ) class VisionPreprocessor(Preprocessor): - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + def __init__(self, config: VisionEncoderConfig, distributed: Distributed): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed = distributed def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( ( TensorDim( - VisionTransformerDimNames.batch, - kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + "vision_batch", + kwargs[AttentionKwargs.micro_batch_size] * kwargs[AttentionKwargs.sequence_q_dim].size, ), - TensorDim(VisionEncoderDimNames.in_channels, 3), - TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), - TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim("in_channels", 3), + TensorDim("patch_size", self._config.patch_size), + TensorDim("patch_size", self._config.patch_size), ), - dtype=self._distributed_config.training_dtype.torch, + dtype=self._distributed.config.training_dtype.torch, ) def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: - images = kwargs.get(VisionEncoderKwargs.images) max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) - im_width = kwargs.get(VisionEncoderKwargs.max_image_size) - patch_size = kwargs[VisionEncoderKwargs.patch_size] - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) - image_sizes = [ - [get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=patch_size) for im in ims] - for ims in images - ] - kwargs[VisionEncoderKwargs.image_sizes] = image_sizes - images = [ - [ - normalize( - resize(image, max_image_size, im_width, patch_size).to( - dtype=self._tensor_space.distributed_config.training_dtype.torch - ) - / kwargs[VisionEncoderKwargs.image_rescale_factor], - mean=kwargs[VisionEncoderKwargs.image_mean], - std=kwargs[VisionEncoderKwargs.image_std], - ) - for image in imgs - ] - for imgs in images - ] + patch_size = self._config.patch_size + image_sizes = [] + + norm_config: 
ImageNormalizationConfig = kwargs["norm_config"] if LanguageModelKwargs.labels in kwargs: labels = kwargs[LanguageModelKwargs.labels] @@ -169,113 +126,113 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # If image break or end token is present, we need to replace image token ids to -100 in labels # TODO: avoid double cloning labels in case of loss masking spans? labels = labels.clone() - patches = [] patch_position_ids = [] - cu_seqlens = [0] - max_seqlen = -1 - kwargs.get(TransformerKwargs.sequence_first) - for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): - # add an empty tensor for clean concatenation in case of no images - seq_patches = [ - torch.tensor([]).to( - dtype=self._tensor_space.distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, + sequence_lengths = [0] + max_sequence_length = -1 + kwargs.get(AttentionKwargs.sequence_first) + + for sample_index, (sample_images_, positions) in enumerate( + zip(kwargs[VisionEncoderKwargs.images], kwargs.get(VisionEncoderKwargs.image_positions), strict=True) + ): + image_sizes.append(sample_image_sizes := []) + + sample_sequence_length = 0 + + for image, position in zip(sample_images_, positions, strict=True): + height, width = get_resize_dims( + image.size(1), image.size(2), max_image_size, max_image_size, patch_size=patch_size ) - ] - sample_cu_seqlen = 0 - for image, size, position in zip(imgs, sizes, positions): - seqlen = get_num_patches(*size, patch_size) - num_tokens = get_num_image_tokens( - *size, - patch_size=patch_size, - image_break=self._config.image_break_token is not None, - image_end=self._config.image_end_token is not None, + + sample_image_sizes.append((height, width)) + + image = resize(image, height, width) + + # TODO: Normalize with constant dtype instead? 
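+                # The full transform below is (pixel / rescale_factor - mean_c) / std_c per
+                # channel c: raw pixel values are first rescaled into [0, 1] (rescale_factor is
+                # presumably 255 for uint8 inputs) and then standardized with the per-channel
+                # mean/std from the image normalization config.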
+ image = image.to(dtype=self._distributed.config.training_dtype.torch) + + image = torchvision_transforms.functional.normalize( + image / norm_config.rescale_factor, + mean=[norm_config.mean_r, norm_config.mean_g, norm_config.mean_b], + std=[norm_config.std_r, norm_config.std_g, norm_config.std_b], ) - if LanguageModelKwargs.labels in kwargs: - # set labels for image patches to -100 - labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 - if seqlen > max_seqlen: - max_seqlen = seqlen - cu_seqlens.append(cu_seqlens[-1] + seqlen) - sample_cu_seqlen += seqlen - seq_patches.append( - torch.cat( - [ - torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( - -1, 3, patch_size, patch_size - ), - ] + patches.extend( + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size ) ) - padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen - if padding_size > max_seqlen: - max_seqlen = padding_size - cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) + + num_height_patches = div(height, patch_size) + num_width_patches = div(width, patch_size) + grid_height = torch.arange(num_height_patches).repeat_interleave(num_width_patches) + grid_width = torch.arange(num_width_patches).repeat(num_height_patches) + grid_height * div(max_image_size, patch_size) + grid_width + patch_position_ids.append(grid_height * div(max_image_size, patch_size) + grid_width) + + if LanguageModelKwargs.labels in kwargs: + num_tokens = get_num_image_tokens( + height, + width, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[sample_index, max(position - 1, 0) : position + num_tokens - 1] = -100 + + sequence_lengths.append(sequence_length := num_height_patches * num_width_patches) + if sequence_length > max_sequence_length: + max_sequence_length = sequence_length + sample_sequence_length += sequence_length + + # TODO: No need for padding with varlen? 
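+            # Padding is appended at the end of the sample (rather than at the image positions)
+            # because flash attention does not support custom attention masks; the padded
+            # patches are appended as a separate segment in the sequence-length bookkeeping
+            # below and receive position id 0.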
+ padding_size = kwargs[AttentionKwargs.sequence_length] - sample_sequence_length + if padding_size > max_sequence_length: + max_sequence_length = padding_size + sequence_lengths.append(padding_size) + patches.append( - torch.cat( - [ - *seq_patches, - torch.zeros(padding_size, 3, patch_size, patch_size).to( - dtype=self._tensor_space.distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, - ), - ] - ) - ) - if sizes: - position_ids = torch.cat( - [position_ids_in_meshgrid(*size, max_image_size // patch_size, patch_size) for size in sizes] - ).to(device=self._tensor_space.distributed.device) - else: - position_ids = torch.tensor( - [], - dtype=torch.int64, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, device=self._tensor_space.distributed.device, - ) - # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks - patch_position_ids.append( - torch.cat( - [ - position_ids, - torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), - ] - ) + ), ) - assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] - patches = torch.cat(patches) - patch_position_ids = torch.cat(patch_position_ids) - kwargs[VisionEncoderKwargs.image_patches] = patches - kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + patch_position_ids.append(torch.full((padding_size,), 0, dtype=torch.int64)) + + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes + kwargs[VisionEncoderKwargs.image_patches] = torch.cat(patches).to(device=self._distributed.device) + kwargs[VisionTransformerKwargs.patch_position_ids] = torch.cat(patch_position_ids).to( + device=self._distributed.device + ) kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], max_image_size, patch_size, - ).to(device=self._tensor_space.distributed.device) - kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) + ).to(device=self._distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size**2, patch_size**2) # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( - cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + cu_seqlens, device=self._distributed.device, dtype=torch.int32 ) kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( - cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + cu_seqlens, device=self._distributed.device, dtype=torch.int32 ) - kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen - kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_sequence_length + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_sequence_length if LanguageModelKwargs.labels in kwargs: kwargs[LanguageModelKwargs.labels] = labels # TODO: add proper preprocessing for attention-mask when not using flash attention # Following is just a dummy code to run the tests. 
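+        # The mask built here is a placeholder: an all-True boolean mask of shape
+        # (1, 1, sequence_length, 1, sequence_length) together with the dtype-minimum mask
+        # value, presumably matching what the backup (non-flash) attention path expects; it
+        # does not yet restrict attention for image patches.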
kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( - (1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]), + (1, 1, kwargs[AttentionKwargs.sequence_length], 1, kwargs[AttentionKwargs.sequence_length]), dtype=torch.bool, device=self._tensor_space.distributed.device, ) kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( [], - torch.finfo(self._distributed_config.training_dtype.torch).min, - dtype=self._distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, + torch.finfo(self._distributed.config.training_dtype.torch).min, + dtype=self._distributed.config.training_dtype.torch, + device=self._distributed.device, ) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py deleted file mode 100644 index 534d813ff..000000000 --- a/fast_llm/models/custom/model.py +++ /dev/null @@ -1,70 +0,0 @@ -import typing - -import torch - -from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import Layer, LossDef -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerBlock -from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig -from fast_llm.models.custom.head import CustomHead -from fast_llm.models.gpt.config import GPTBaseModelConfig -from fast_llm.models.gpt.model import GPTBaseModel, GPTModel -from fast_llm.tensor import TensorMeta - - -class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - - def __init__( - self, - config: CustomBaseModelConfig, - distributed_config: DistributedConfig, - ): - # TODO: Implement / update. - super().__init__(config, distributed_config) - - def get_layers(self) -> list[Layer]: - # TODO: Adjust as needed. - return [ - LanguageModelEmbedding(self._config, self._tensor_space), - *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - ) - for i in range(self._config.transformer.num_layers) - ], - CustomHead(self._config, self._tensor_space), - ] - - def preprocess_meta( - self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType - ) -> list[tuple[TensorMeta, dict]]: - # TODO: Adjust or reimplement. - return super().preprocess_meta(batch_meta, phase) - - def preprocess( - self, - batch: GPTBatch, - preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, - *, - phase: PhaseType, - iteration: int, - metrics: dict | None = None, - ) -> list[tuple[torch.Tensor, dict]]: - # TODO: Adjust or reimplement. - return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - - @property - def loss_defs(self) -> list[LossDef]: - # TODO: Adjust or reimplement. 
- return super().loss_defs - - -class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): - config_class: typing.ClassVar[type[CustomModelConfig]] = CustomModelConfig - base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 7712d764d..58d1ed04c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -130,25 +130,6 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelF return HuggingfaceGPTModelForCausalLM - @classmethod - def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: - if isinstance(format, type) and issubclass(format, CheckpointFormat): - format_ = cls.get_checkpoint_format(format.name) - Assert.is_(format, format_) - return format_ - elif isinstance(format, dict): - for format_ in cls.checkpoint_formats: - if format_.name == format["name"]: - if (vision_name := format.get("vision_name")) is not None: - format_.vision_name = vision_name - if (text_name := format.get("text_name")) is not None: - format_.text_name = text_name - return format_ - for format_ in cls.checkpoint_formats: - if format_.name == format: - return format_ - raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") - @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @@ -202,10 +183,6 @@ def _validate(self) -> None: ) Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) - if self.model.base_model.vision_encoder.enabled: - assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder" - Assert.gt(self.batch.max_image_size, 0) - @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bbe9f5cb6..b5f429376 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -11,17 +11,10 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding -from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs -from fast_llm.layers.vision_encoder.patch_conv import PatchConv -from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -51,59 +44,28 @@ def __init__( param, self._config.decoder.block, config.embeddings_layer.hidden_size ) # Noqa # `self._reference_models` is not populated at this point, so we pass a 
mutable dict. - self._preprocessors: list[Preprocessor] = [] - if self._config.use_absolute_position_embeddings: - self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) - # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. - # TODO: Find a better solution. - self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) - if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) - else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) - if self._config.enable_dpo: # TODO better way to pass in? - self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) + # self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) - if self._config.vision_encoder.enabled: - self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) - self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) - - def get_output_layers(self) -> list[Layer]: - layers = [] - for i in range(self._config.prediction_heads): - if i > 0: - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=i < self._config.prediction_heads - 1, - ) - ) - layers.append( - LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, - ) - ) - return layers - - def get_vision_layers(self) -> list[Layer]: - vit_layers = [ - VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) - for idx in range(self._config.vision_encoder.transformer.num_layers) - ] - return [ - PatchConv(self._config.vision_encoder, self._tensor_space), - *vit_layers, - VisionAdapter(self._config.vision_encoder, self._tensor_space), - MultiModalEmbedding(self._config, self._tensor_space), - ] + def get_layers(self) -> list[Layer]: + return self._config.get_blocks(self._distributed_config) + + # TODO ====== Vision ====== + # def get_vision_layers(self) -> list[Layer]: + # vit_layers = [ + # VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) + # for idx in range(self._config.vision_encoder.transformer.num_layers) + # ] + # return [ + # PatchConv(self._config.vision_encoder, self._tensor_space), + # *vit_layers, + # VisionAdapter(self._config.vision_encoder, self._tensor_space), + # MultiModalEmbedding(self._config, self._tensor_space), + # ] def get_embedding_layers(self) -> list[Layer]: if self._config.vision_encoder.enabled: @@ -111,23 +73,6 @@ def get_embedding_layers(self) -> list[Layer]: else: return [LanguageModelEmbedding(self._config, self._tensor_space)] - def get_layers(self) -> list[Layer]: - return [ - *(self.get_embedding_layers()), - *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, - ) - for i in range(self._config.transformer.num_layers) - ], - *self.get_output_layers(), - ] - def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -146,41 +91,8 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True - if self._config.vision_encoder.enabled: - try: - max_image_size = batch_meta.max_image_size - except AttributeError: - max_image_size = 256 - logger.warning("Inference mode: max_image_size not provided, defaulting to 256") - image_mean = [ - self._config.vision_encoder.image_normalization.mean_r, - self._config.vision_encoder.image_normalization.mean_g, - self._config.vision_encoder.image_normalization.mean_b, - ] - image_std = [ - self._config.vision_encoder.image_normalization.std_r, - self._config.vision_encoder.image_normalization.std_g, - self._config.vision_encoder.image_normalization.std_b, - ] - image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor - vision_kwargs = { - VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, - VisionEncoderKwargs.max_image_size: max_image_size, - VisionEncoderKwargs.image_mean: image_mean, - VisionEncoderKwargs.image_std: image_std, - VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, - VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, - VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, - VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, - } - else: - vision_kwargs = {} - - batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) - - if isinstance(batch_meta, GPTBatchConfig): - micro_sequence_length = batch_meta.micro_sequence_length + batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -210,24 +122,11 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, self._hidden_dim) if sequence_first else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) ) - if self._config.vision_encoder.enabled: - vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] - vision_hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) - if sequence_first - else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) - ) - vision_kwargs.update( - { - VisionTransformerKwargs.hidden_dims: vision_hidden_dims, - } - ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -235,10 +134,35 @@ def preprocess_meta( AttentionKwargs.hidden_dims: hidden_dims, AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_q_dim: sequence_q_dim, - AttentionKwargs.micro_batch_size: micro_batch_size, LanguageModelKwargs.mask_inputs: not truncate_documents, } - common_kwargs.update(vision_kwargs) + + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # try: + # max_image_size = batch_meta.max_image_size 
+ # except AttributeError: + # max_image_size = 256 + # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + # vision_kwargs = { + # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + # VisionEncoderKwargs.max_image_size: max_image_size, + # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + # } + # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + # vision_hidden_dims = ( + # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + # if sequence_first + # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + # ) + # vision_kwargs.update( + # { + # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + # } + # ) + # common_kwargs.update(vision_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._distributed_config.sequence_data_rank, @@ -284,11 +208,14 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - if self._config.vision_encoder.enabled: - # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size - preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) - else: - preprocessed_meta.append((tokens, kwargs)) + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + # preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + # else: + # preprocessed_meta.append((tokens, kwargs)) + + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -373,12 +300,10 @@ def preprocess( labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config - labels_cloned = False if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - labels_cloned = True - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -391,20 +316,21 @@ def preprocess( valid_spans -= sequence_offset for start, end in valid_spans: if sequence_first: - labels[start : end + 1, i] = -100 + labels[start : end + 1, idx] = -100 else: - labels[i, start : end + 1] = -100 - if self._config.vision_encoder.enabled: - if self._config.vision_encoder.image_break_token is not None: - if not labels_cloned: - labels = labels.clone() - labels_cloned = True - labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) - if self._config.vision_encoder.image_end_token is not None: - if not labels_cloned: - labels = labels.clone() - labels_cloned = True - labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + labels[idx, start : end + 1] = -100 + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # if self._config.vision_encoder.image_break_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + # if self._config.vision_encoder.image_end_token is 
not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) # Loss-masking for distillation losses if self._config.distillation_model is not None: loss_mask = torch.ones_like(labels, dtype=torch.bool) @@ -413,48 +339,58 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - if self._config.vision_encoder.enabled: - batch_images = ( - batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] - ) - kwargs[VisionEncoderKwargs.images] = [ - [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for img in images - ] - for images in batch_images - ] - kwargs[VisionEncoderKwargs.image_positions] = ( - batch.image_positions - if batch.image_positions is not None - else [[]] * kwargs[AttentionKwargs.micro_batch_size] - ) - kwargs[LanguageModelKwargs.tokens] = tokens + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # batch_images = ( + # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[VisionEncoderKwargs.images] = [ + # [ + # img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + # for img in images + # ] + # for images in batch_images + # ] + # kwargs[VisionEncoderKwargs.image_positions] = ( + # batch.image_positions + # if batch.image_positions is not None + # else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[LanguageModelKwargs.tokens] = tokens for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) - image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) - if image_patches is not None: - preprocessed.append((image_patches, kwargs)) - else: - preprocessed.append((tokens, kwargs)) + + # TODO ====== Vision ====== + # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + # if image_patches is not None: + # preprocessed.append((image_patches, kwargs)) + # else: + # preprocessed.append((tokens, kwargs)) + + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self.embedding_layer_index] + return self.layers[0] - @property - def transformer_layers(self) -> list[TransformerBlock]: - return self.layers[self.embedding_layer_index + 1 : -1] + # TODO ====== Vision ====== + # @property + # def embedding(self) -> LanguageModelEmbedding: + # return self.layers[self.embedding_layer_index] - @property - def embedding_layer_index(self) -> int: - if self._config.vision_encoder.enabled: - return self._config.vision_encoder.transformer.num_layers + 2 - else: - return 0 + # @property + # def transformer_layers(self) -> list[TransformerBlock]: + # return self.layers[self.embedding_layer_index + 1 : -1] + + # @property + # def embedding_layer_index(self) -> int: + # if self._config.vision_encoder.enabled: + # return self._config.vision_encoder.transformer.num_layers + 2 + # else: + # return 0 @property def model_head(self) -> LanguageModelHead: @@ -469,7 +405,9 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (self.embedding_layer_index, *self.model_head_indices), + # TODO ====== Vision ====== + # (self.embedding_layer_index, *self.model_head_indices), + (0, *self.model_head_indices), ) } 
elif self._config.output_layer.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index cc676d184..171a9a9b2 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -31,13 +31,15 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.output_layer.prediction_heads, } ) - if self._config.model.base_model.vision_encoder.enabled: - parameters.update( - { - "patch_size": self._config.model.base_model.vision_encoder.patch_size, - "max_image_size": self._config.batch.max_image_size, - "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, - "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, - } - ) + + # TODO ====== Vision ====== + # if self._config.model.base_model.vision_encoder.enabled: + # parameters.update( + # { + # "patch_size": self._config.model.base_model.vision_encoder.patch_size, + # "max_image_size": self._config.batch.max_image_size, + # "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + # "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, + # } + # ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/tests/data/common.py b/tests/data/common.py index b7284872f..afee8694e 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -218,16 +218,11 @@ def name(self) -> str: def __len__(self) -> int: return self._config.num_documents - def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64), np.array( - [], dtype=np.int64 - ) + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: return self._config.num_tokens_per_document def get(self, index: int, *args, **kwargs) -> typing.Any: raise NotImplementedError() - - def has_images(self) -> bool: - return False diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index e04af129a..2f9b143b9 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.utils import Assert @@ -78,7 +78,7 @@ def get(self, index: int, offset=0, length=None, use_loss_masking_spans: bool = def __len__(self) -> int: return len(self._samples) - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: doc_sizes = [] im_sizes = [] for index in range(len(self)): @@ -87,16 +87,12 @@ def get_document_sizes(self) -> np.ndarray: im_sizes.append(im_size) return np.array(doc_sizes, dtype=np.int64), np.array(im_sizes, dtype=np.int64) - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: return len(self._samples[index]), [] def name(self) -> str: return "dataset" - @property - 
def has_images(self) -> bool: - return False - TEST_DATASET = SimpleGPTIndexedDataset( [ diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index bbfa4b21a..926bcc346 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -35,6 +35,7 @@ def run(self): dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), "num_documents": len(dataset), + # Todo: fix "num_tokens": dataset.num_tokens, } if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: diff --git a/tools/generate_config_yaml_for_sharded_dst.py b/tools/generate_config_yaml_for_sharded_dst.py deleted file mode 100644 index c0b4fa24d..000000000 --- a/tools/generate_config_yaml_for_sharded_dst.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import pathlib - -import yaml - -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator, GPTMemmapDatasetPreparatorConfig - -""" -This script is intended to be used only for creation of fast_llm_config.yaml files for sharded datasets encoded with older version of the prepare command. -""" - - -def read_dataset_shard_config(shard_path): - """ - Read a dataset shard from the given path. - - Args: - shard_path: Path to the shard prefix (without .idx or .bin extension) - - Returns: - A GPTMemmapDataset instance - """ - # Convert to pathlib.Path if it's a string - path = pathlib.Path(shard_path) if isinstance(shard_path, str) else shard_path - - # Create a GPTMemmapDataset instance - # The name parameter is just for identification - dataset = GPTMemmapDataset(name=path.name, prefix=path) - - # Print basic information about the dataset - print(f"Dataset: {dataset.name}") - print(f"Number of documents: {dataset._num_documents}") - print(f"Number of tokens: {dataset.num_tokens}") - - return GPTMemmapDatasetConfig.from_dict( - { - "type": "memmap", - "path": path.name.replace(".bin", ""), - "num_documents": dataset._num_documents, - "num_tokens": dataset.num_tokens, - } - ) - - -def get_preparator(prepare_config: GPTMemmapDatasetPreparatorConfig) -> GPTMemmapDatasetPreparator: - config = GPTMemmapDatasetPreparatorConfig.from_dict( - { - "output_path": prepare_config.output_path, - "dataset": {"path": prepare_config.dataset.path}, - "tokenizer": {"path": prepare_config.tokenizer.path}, - }, - {}, - ) - return config.get_dataset_preparator_class()(config=config) - - -def main(config_dict): - prepare_config = GPTMemmapDatasetPreparatorConfig.from_dict(config_dict) - destination = pathlib.Path(prepare_config.output_path) - - shards = list(destination.glob("shard_*.bin")) - dataset_configs = [read_dataset_shard_config(shard) for shard in shards] - - preparator = get_preparator(prepare_config) - preparator.generate_config_yaml_for_sharded_dst(dataset_configs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generate config YAML for sharded datasets") - parser.add_argument( - "--prepare_config", - type=str, - required=False, - default=None, # "/home/toolkit/dev/Fast-LLM/.vscode/prepare_dst.yaml", - help="Path to the prepare config YAML file", - ) - parser.add_argument( - "--dataset_path", - type=str, - required=False, - default="/mnt/datasets/tokenized/Mistral-Nemo-Base-2407/FineWeb2/deu_Latn/", - help="Path to the dataset path", - ) - args = parser.parse_args() - - if args.prepare_config: - with open(args.prepare_config) as f: - config_dict = 
yaml.safe_load(f) - else: - assert args.dataset_path is not None, "Please provide a prepare config YAML file or dataset path" - config_dict = { - "output_path": args.dataset_path, - "dataset": {"path": "unknown"}, - "tokenizer": {"path": "no_tokenizer"}, - } - main(config_dict) From a44642c4ce3d66b37a8d3deac0a0830275417b65 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 15:01:33 -0400 Subject: [PATCH 03/15] Drop varlen mamba --- Dockerfile | 3 +- fast_llm/layers/ssm/preprocessing.py | 68 ------ tests/test_ssms.py | 349 --------------------------- 3 files changed, 1 insertion(+), 419 deletions(-) delete mode 100644 fast_llm/layers/ssm/preprocessing.py delete mode 100644 tests/test_ssms.py diff --git a/Dockerfile b/Dockerfile index abb0759b9..d9258159b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,9 +29,8 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -# Using varlen_mamba for variable length sequence support RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ diff --git a/fast_llm/layers/ssm/preprocessing.py b/fast_llm/layers/ssm/preprocessing.py deleted file mode 100644 index 343f0bb28..000000000 --- a/fast_llm/layers/ssm/preprocessing.py +++ /dev/null @@ -1,68 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.ssm.config import SSMKwargs -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -class Mamba2Preprocessor(Preprocessor): - def __init__(self, config: HybridSSMBaseModelConfig, tensor_space: TensorSpace): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._transformer_dim_names = config.transformer._transformer_dim_names - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - """ - Simplified preprocessor that does not take into account micro-sequences. - """ - if TransformerKwargs.sequence_lengths not in kwargs: - return - sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] - if TransformerKwargs.cu_seqlens_k in kwargs: - # already set this in the transformer preprocessor, so we can use it here - cu_seqlens_k = kwargs[TransformerKwargs.cu_seqlens_k] - cu_seqlens_q = kwargs[TransformerKwargs.cu_seqlens_q] - Assert.eq( - cu_seqlens_k.shape[0], - cu_seqlens_q.shape[0], - msg="cu_seqlens_k and cu_seqlens_q have different lengths, is micro_sequence_length being used? 
This is currently not supported for Mamba.", - ) - Assert.all_equal(cu_seqlens_k, cu_seqlens_q) - cu_seqlens = cu_seqlens_k - else: - seqlens = torch.cat(sequence_lengths) - cu_seqlens = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), - ) - ) - kwargs[SSMKwargs.cu_seqlens] = cu_seqlens - # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152 - kwargs[SSMKwargs.seq_idx] = torch.cat( - [ - torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) - for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) - ], - dim=0, - ).unsqueeze(0) - - sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(self._tensor_space.distributed.device, dtype=torch.int64) - position_ids = position_ids[ - :, sequence_k - sequence_q : sequence_k - ] # this is only needed if we do micro-sequences? - kwargs[SSMKwargs.ssm_position_ids] = position_ids.to(torch.int32) diff --git a/tests/test_ssms.py b/tests/test_ssms.py deleted file mode 100644 index 2a338f1ba..000000000 --- a/tests/test_ssms.py +++ /dev/null @@ -1,349 +0,0 @@ -import inspect -import itertools -import pathlib -from functools import partial - -import pytest -import torch -from mamba2 import Mamba2 - -from fast_llm.config import NoAutoValidate -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs -from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMModel - -_mamba_varlen = False -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa - - _mamba_available = True - sig = inspect.signature(selective_scan_fn) - if "position_indices" in sig.parameters: - _mamba_varlen = True - else: - _mamba_varlen = False - # for training with packing install https://github.com/jxiw/varlen_mamba - # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md - -except (ImportError, RuntimeError): - _mamba_available = False - - -def get_hybrid_config(hybrid_block_layout=["t", "m2"], prediction_heads=1, default_mtp_type=None): - hidden_size = 512 - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout), hidden_size=hidden_size), - ssm=SSMConfig(d_xb=hidden_size, dt_rank=10, d_inner=hidden_size * 2), - hybrid_block_layout=hybrid_block_layout, - prediction_heads=prediction_heads, - default_mtp_type=default_mtp_type, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - 
use_position_embeddings=True, - tie_word_embeddings=False, - ) - return config - - -@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") -@pytest.mark.slow -def test_load_from_llamba_checkpoint(): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. - """ - import cartesia_pytorch.Llamba.llamba - - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - - hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) - hf_logits = hf_model(x)["logits"].cpu() - del hf_model - torch.cuda.empty_cache() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - # model = GPTModel.from_pretrained(checkpoint_config) - assert model.config.base_model.vocab_size == vocab_size - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(DistributedConfig.from_dict({})) - batch_config.validate() - schedule_runner = ScheduleRunner( - config=schedule_config, - multi_stage=model, - distributed_config=model.distributed.config, - ) - schedule = Schedule( - multi_stage=model, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=model.distributed.config, - phase=PhaseType.inference, - ) - schedule_runner.setup(model.distributed, optimizer=None) - - common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] - - schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) - - logits = input_data[0][1]["logits"].cpu() - assert torch.allclose(logits, hf_logits, atol=1e-2) - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def unpack(packed_hidden_states, cu_seqlens): - batch_size = packed_hidden_states.shape[0] - package_num = cu_seqlens.shape[0] - 1 - seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - hidden_dim = packed_hidden_states.shape[2] - hidden_states = torch.zeros( - package_num * batch_size, - seq_len, - hidden_dim, - dtype=packed_hidden_states.dtype, - device=packed_hidden_states.device, - ) - for j in range(batch_size): - for i in range(package_num): - line = j * package_num + i - hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ - j, cu_seqlens[i] : cu_seqlens[i + 1], : - ] - return hidden_states - - -def pack(hidden_states, cu_seqlens, batch_size): - package_num, seq_len, hidden_dim = hidden_states.shape - seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] - seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) - indices_3d = ( - torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) - ) - mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) - packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) - return packed_hidden_states - - -def generate_random_cu_seqlens(seq_len, packages_num=2): - if packages_num < 1: - raise ValueError("packages_num must be at least 1") - - # base size of each chunk, and how many get an extra token - base, rem = divmod(seq_len, packages_num) - # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] - lengths = [base + 1 if i < rem else base for i in range(packages_num)] - - # split points exclude the final cumulative (seq_len) - split_points = list(itertools.accumulate(lengths))[:-1] - - # cu_seqlens = [0] + split_points + [seq_len] - cu_seqlens = [0] + split_points + [seq_len] - - # index: for each chunk, we emit 0,1,...,length-1 - index = [] - for length in lengths: - index.extend(range(length)) - - # sanity check - assert len(cu_seqlens) - 1 == packages_num - assert sum(lengths) == seq_len - assert len(index) == seq_len - - return cu_seqlens, index - - -# Quick and dirty test for Mamba2 varlen block from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/tests/pack_mamba/test_mamba_layer.py -# TODO: integrate in the testing framework -@pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") -@pytest.mark.skipif(not _mamba_available, reason="Mamba2 is not available") -@pytest.mark.skipif(not _mamba_varlen, reason="Mamba2 varlen is not available") -def test_mamba_varlen_block(distributed_config, distributed): - """ - Compare that the output and grads of packed and unpacked Mamba2 varlen block are the same. 
- """ - hybrid_config = get_hybrid_config(hybrid_block_layout=["m2", "t"]) - tensor_space = TensorSpace(distributed_config=distributed_config) - tensor_space.setup(distributed) - hybrid_config.setup_tensor_space(tensor_space) - layer_idx = 0 - - mixer_cls = partial(Mamba2, block_index=layer_idx) - block_packed = SSMBlock( - hybrid_config.transformer, - hybrid_config.ssm, - mixer_cls=mixer_cls, - tensor_space=tensor_space, - block_index=layer_idx, - ) - block_ref = SSMBlock( - hybrid_config.transformer, - hybrid_config.ssm, - mixer_cls=mixer_cls, - tensor_space=tensor_space, - block_index=layer_idx, - ) - device = "cuda" - materialize_meta_tensors(block_packed, tensor_space) - materialize_meta_tensors(block_ref, tensor_space) - block_ref.load_state_dict(block_packed.state_dict()) - block_packed.to(device) - block_ref.to(device) - - batch_size = 2 - seq_len = 64 - packages_num = 2 - hidden_dim = hybrid_config.transformer.hidden_size - - cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num=packages_num) - cu_seqlens = torch.tensor(cu_seqlens).cuda() - ssm_position_ids = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda() - seq_idx = ( - torch.cat( - [ - torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) - for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) - ], - dim=0, - ) - .unsqueeze(0) - .repeat(batch_size, 1) - ) - - # Generate packed_hidden_states with random values for testing - hidden_states_list = [ - torch.randn(l, hidden_dim, device=device, dtype=torch.bfloat16, requires_grad=True) - for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - ] - packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) - packed_hidden_states = packed_hidden_states.expand(batch_size, -1, -1).contiguous() - # hidden_states should be forwarded without cu_seqlens - hidden_states = unpack(packed_hidden_states, cu_seqlens) - - # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states - assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] - # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states - assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] - - output_states_packed = block_packed( - packed_hidden_states, - {"cu_seqlens": cu_seqlens, "seq_idx": seq_idx, "ssm_position_ids": ssm_position_ids, "sequence_first": False}, - ) - output_states_unpacked = block_ref( - hidden_states.clone(), {"cu_seqlens": None, "seq_idx": None, "ssm_position_ids": None, "sequence_first": False} - ) - tollerance = 1e-4 - assert output_states_packed.shape == packed_hidden_states.shape - assert output_states_unpacked.shape == hidden_states.shape - assert not torch.isnan(hidden_states).any() - assert not torch.isinf(hidden_states).any() - - output_states_unpacked = pack(output_states_unpacked, cu_seqlens, batch_size) - torch.allclose(output_states_packed, output_states_unpacked, atol=tollerance) - - loss = output_states_packed.sum() - loss.backward() - loss_ref = output_states_unpacked.sum() - loss_ref.backward() - assert torch.allclose(block_packed.mixer.conv1d_weight.grad, block_ref.mixer.conv1d_weight.grad, atol=tollerance) - assert torch.allclose(block_packed.mixer.conv1d_bias.grad, block_ref.mixer.conv1d_bias.grad, atol=tollerance) - assert torch.allclose( - block_packed.mixer.in_proj.weight.grad_buffer, block_ref.mixer.in_proj.weight.grad_buffer, atol=tollerance - ) - assert 
torch.allclose( - block_packed.mixer.out_proj.weight.grad_buffer, block_ref.mixer.out_proj.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mixer.dt_in_proj.weight.grad_buffer, - block_ref.mixer.dt_in_proj.weight.grad_buffer, - atol=tollerance, - ) - - assert torch.allclose( - block_packed.mlp.layer_1.weight.grad_buffer, block_ref.mlp.layer_1.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mlp.layer_1.bias.grad_buffer, block_ref.mlp.layer_1.bias.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mlp.layer_2.weight.grad_buffer, block_ref.mlp.layer_2.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mlp.layer_2.bias.grad_buffer, block_ref.mlp.layer_2.bias.grad_buffer, atol=tollerance - ) - - -if __name__ == "__main__": - pytest.main([__file__]) From ddf21437065da569884b6078a2d4ad8d0b35735c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 15:50:26 -0400 Subject: [PATCH 04/15] cleanup --- fast_llm/data/data/gpt/config.py | 40 +---- fast_llm/data/dataset/config.py | 19 +-- fast_llm/functional/config.py | 14 +- fast_llm/functional/cross_entropy.py | 210 +++---------------------- fast_llm/layers/language_model/head.py | 120 ++++---------- 5 files changed, 58 insertions(+), 345 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 405d1c672..efee46959 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,23 +1,16 @@ import logging -import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import ( - GPTLegacyConfig, - GPTLegacyDatasetConfig, - GPTSampledDatasetConfig, - GPTSamplingConfig, -) -from fast_llm.engine.distributed.config import PhaseType +from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @config_class() -class GPTDataConfig(DataConfig, GPTLegacyConfig): +class GPTDataConfig(DataConfig): """ Configuration for the dataset(s), split and sampling. Currently hard-coded to a GPT dataset. @@ -48,32 +41,3 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) - - def _validate(self) -> None: - if not self.datasets: - logger.warning( - "Using the legacy dataset definition format." " Specify it through `data.datasets` instead." - ) - self.datasets = { - phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False) - for phase in (PhaseType.training, PhaseType.validation, PhaseType.test) - } - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. 
- if "datasets" in default: - for phase in PhaseType: - if phase.value in default["datasets"]: - rename = phase.value.lower() - logger.warning(f"Renaming dataset {phase.value} to {rename}") - assert rename not in default["datasets"] - default["datasets"][rename] = default["datasets"].pop(phase.value) - - return super()._from_dict(default, strict, flat) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 5e3ced8a4..0c1b0cd09 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -204,11 +204,6 @@ class BlendedDatasetConfig(SampledDatasetConfig): desc="The blending weight of each dataset.", hint=FieldHint.core, ) - legacy: bool = Field( - default=False, - desc="Use the legacy formulas for sub-dataset seeds and sample sizes.", - hint=FieldHint.deprecated, - ) def _validate(self) -> None: self.weights = normalize_probabilities(self.weights) @@ -231,20 +226,10 @@ def build_and_sample( sampling, parameters=dataclasses.replace( sampling.parameters, - num_samples=( - math.ceil( - weight - * ( - sampling.parameters.num_samples - + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5 - ) - ) - if self.legacy - else math.ceil(weight * sampling.parameters.num_samples) + 1 - ), + num_samples=math.ceil(weight * sampling.parameters.num_samples) + 1, ), # TODO: Seed may not be unique for nested blended datasets. - config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}), + config=sampling.config.to_copy({"seed": sampling.config.seed + i * 697}), ), ) for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 5c8d75a6f..684193848 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,7 +40,6 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" - gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -68,8 +67,7 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: torch.nn.functional.gelu, - ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -80,8 +78,7 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu", - ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", @@ -89,16 +86,9 @@ def _set_activation_fn_map() -> None: } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} - MAX_DROPLESS_BLOCK_SIZE_ROW = 128 -class ReverseKLImpl(str, enum.Enum): - tp = "tp" - stp = "stp" - no_tp = "no_tp" - - class CrossEntropyImpl(str, enum.Enum): auto = "auto" torch = "torch" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d9ca547a7..d56dce98d 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, 
ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -49,19 +49,6 @@ def _torch_cross_entropy_forward_backward( return loss.detach_(), grad -def distributed_log_softmax(logits: torch.Tensor, group: ProcessGroup, dim: int = -1): - logits = logits.float() - local_max = logits.max(dim=dim, keepdim=True)[0] - all_reduce(local_max, op=ReduceOp.MAX, group=group) - - logits_shifted = logits - local_max - exp_logits = torch.exp(logits_shifted) - sum_exp = exp_logits.sum(dim=dim, keepdim=True) - all_reduce(sum_exp, op=ReduceOp.SUM, group=group) - - return logits_shifted - sum_exp.log() # log_softmax - - @torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 @@ -164,8 +151,7 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= group.size() + all_reduce(loss, op=ReduceOp.MEAN, group=group) return loss, grad @@ -227,30 +213,20 @@ def cross_entropy_forward_backward( ) -def _torch_reverse_kl_forward_backward_vocab_parallel( +def _torch_reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, + logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - This is used for TP version where we split accross vocab dimantion. - This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. - In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. + Much simpler and more reliable than custom implementation! 
""" - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype @@ -258,78 +234,32 @@ def _torch_reverse_kl_forward_backward_vocab_parallel( Assert.eq(loss_mask.shape, logits.shape[:-1]) # Compute log probabilities - let _fused_softmax handle scaling internally - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - batch_size = logits.shape[0] - with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_(grad_output is not None) - student_log_probs = distributed_log_softmax(logits_, group=group) - - # Reverse KL: input=teacher_log_probs, target=student_probs - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="sum", - log_target=True, - ) - else: - # Apply loss mask - this requires some reshaping - raise NotImplementedError("Loss mask not implemented with TP for reverse KL , it must be doublechecked") - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).sum() - - if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= batch_size - - if grad_output is not None: - loss.backward(torch.full_like(loss, grad_output)) - grad = logits_.grad.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - + # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) + # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) + # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 + # teacher_log_probs = torch.log(teacher_probs) -def _torch_reverse_kl_forward_backward_no_tp( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Reverse KL using PyTorch's native kl_div function. - THis is only used for no-TP case. 
- """ - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) # Scale target logits more carefully scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) - # Clamp to prevent extreme values that cause NaNs in log_softmax - scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) - teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) + # Clamp to prevent extreme values before log_softmax + scaled_target = torch.clamp(scaled_target, min=-50, max=50) + teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) # Use kl_div with: input=log(p), target=q, log_target=False # This gives: Σ q * (log(q) - log(p)) = exactly what we want! with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_(grad_output is not None) + logits_ = logits.detach().requires_grad_(grad_output is not None) + # Use log_softmax for consistency instead of _fused_softmax scaled_logits = logits_ * logits_scale_factor - # Clamp to prevent extreme values that cause NaNs in log_softmax - scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) - student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) + scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) + student_log_probs = torch.log_softmax(scaled_logits, dim=-1) + + # Convert to probabilities for kl_div + # student_probs_ = torch.exp(student_log_probs) # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: @@ -344,85 +274,12 @@ def _torch_reverse_kl_forward_backward_no_tp( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum() - - if grad_output is not None: - # note, we never get here in TP over seq. dim. - loss.backward(torch.full_like(loss, grad_output)) - grad = logits_.grad.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - -def _torch_reverse_kl_forward_backward_sequence_tensor_parallel( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, - total_valid_tokens: int | None = None, # total number of unmasked tokens in the batch - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Reverse KL using PyTorch's native kl_div function. - THis is only used for sequence-tensor-parallel case where we split over sequence dimension. 
- """ - Assert.eq( - total_valid_tokens is not None, - msg="Total valid tokens must be provided for sequence-tensor-parallel reverse KL", - ) - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - # Scale target logits more carefully - scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) - # Clamp to prevent extreme values that cause NaNs in log_softmax - scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) - - teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) - - # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) - # Use kl_div with: input=log(p), target=q, log_target=False - # This gives: Σ q * (log(q) - log(p)) = exactly what we want! - - with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_(grad_output is not None) - - scaled_logits = logits_ * logits_scale_factor - # Clamp to prevent extreme values that cause NaNs in log_softmax - scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) - student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) + loss = (loss_per_sample * loss_mask).mean() - # Reverse KL: input=teacher_log_probs, target=student_probs - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="sum", - log_target=True, - ) - else: - # Apply loss mask - this requires some reshaping - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).sum() # this can be 0.0 if all tokens are masked + if group is not None and target_format != TargetFormat.labels: + all_reduce(loss, op=ReduceOp.MEAN, group=group) if grad_output is not None: - # note, if we compute gradient w.r.t sum of losses, - # and grad_output should reflect the scaling by 1/valid samples loss.backward(torch.full_like(loss, grad_output)) grad = logits_.grad.to(logits.dtype) else: @@ -431,13 +288,6 @@ def _torch_reverse_kl_forward_backward_sequence_tensor_parallel( return loss.detach_(), grad -REVERSE_KL_IMPLEMENTATIONS = { - ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward_no_tp, - ReverseKLImpl.tp: _torch_reverse_kl_forward_backward_vocab_parallel, - ReverseKLImpl.stp: _torch_reverse_kl_forward_backward_sequence_tensor_parallel, -} - - def reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -447,8 +297,6 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - reverse_kl_impl: ReverseKLImpl = ReverseKLImpl.no_tp, - total_valid_tokens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). 
@@ -491,15 +339,7 @@ def reverse_kl_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # TODO: implement fused reverse KL? - return REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - teacher_softmax_temperature=teacher_softmax_temperature, - group=group, - total_valid_tokens=total_valid_tokens, + # TODO: implement fused? + return _torch_reverse_kl_forward_backward( + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21753006a..ade1144d2 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -11,13 +11,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import ( - CrossEntropyImpl, - DistillationLossImpl, - ReverseKLImpl, - TargetFormat, - TritonConfig, -) +from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward @@ -58,13 +52,17 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, + return_input: bool = False, ): + if return_input: + raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_input=return_input, ) self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -244,24 +242,10 @@ def _get_targets( ).flatten() else: lm_target = None - targets = (dpo_target, lm_target, distillation_target) - # If we do distillation, no need to split it here as it has already been split in the embedding layer! - # if we do CPT/language modeling, we need to split the targets here! - if ( - self._config.distillation_model is not None - and self._sequence_parallel_logits - and not self._parallel_embeddings - and not self._sequence_parallel - ) or (self._config.distillation_model is None and self._sequence_parallel_logits): - # We dont split targets if they already have been split in the embedding layer! - targets = [ - None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) - for target in targets - ] - # Loss mask may need to be split. It was not split in the embedding layer as it is not used there. 
- if loss_mask is not None and self._sequence_parallel_logits: - loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) - targets = (*targets, loss_mask) + + targets = (dpo_target, lm_target, distillation_target, loss_mask) + if self._sequence_parallel_logits: + targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] if not any(target is not None for target in targets): # Simplify so we don't have to check every time. targets = None @@ -321,13 +305,14 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - assert self._cross_entropy_splits is None, "This is not supported for now" - # loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) - # if loss_count != 1: - # loss.div_(loss_count) - # if self._sequence_parallel_logits: - # # TODO: Async - # all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + loss_count = (self._config.cross_entropy_splits or 1) * ( + self._parallel_dim.size if self._sequence_parallel_logits else 1 + ) + if loss_count != 1: + loss.div_(loss_count) + if self._sequence_parallel_logits: + # TODO: Async + all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -358,34 +343,13 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._config.logits_scale_factor, ) - if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ] - dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) - ) - - dim_names = ( - [TransformerDimNames.sequence_q_tp, LanguageModelDimNames.vocab] - if self._sequence_parallel_logits - else [TransformerDimNames.sequence_q, LanguageModelDimNames.vocab_tp] - ) - - dim_names.insert(int(kwargs[TransformerKwargs.sequence_first]), TransformerDimNames.batch) - log_distributed_tensor( - "", - logits, - level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), - distributed=self._tensor_space.distributed, - scale=self._logits_scale_factor, + if self._debug.enabled and self._config.cross_entropy_splits is None: + sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q + batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] + dims = ( + (sequence_dim, batch_dim, self._vocab_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (batch_dim, sequence_dim, self._vocab_dim) ) self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) @@ -421,31 +385,8 @@ def _logits_cross_entropy_forward_backward( else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._distillation_loss_implementation == 
DistillationLossImpl.reverse_kl: - local_valid_tokens = total_valid_tokens = logits.shape[0] - if logits.shape[-1] != self._config.vocab_size: - reverse_kl_impl = ReverseKLImpl.tp - assert loss_mask is None, "Loss mask is not implemented for TP (vocab dim) reverse KL yet" - elif self._sequence_parallel_logits: - # grad_output already reflects scaling 1/ number of ranks (group_size), see _forward_backward - reverse_kl_impl = ReverseKLImpl.stp - if loss_mask is not None: - local_valid_tokens = loss_mask.sum() - total_valid_tokens = local_valid_tokens.clone() - all_reduce( - total_valid_tokens, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group - ) - else: - local_valid_tokens = logits.shape[0] - total_valid_tokens = local_valid_tokens * self._group_size - # in the loss function we compute grads w.r.t sum of losses, - # so we need to multiply back by the group size and divide by the number of valid tokens to get the correct scaling - # note, the function returns the sum of local losses, so we need to handle this properly for reporting - grad_output *= self._group_size / total_valid_tokens # multiply back by the group size - else: - reverse_kl_impl = ReverseKLImpl.no_tp - + if distillation_target is not None and self._config.distillation_loss_factor > 0.0: + if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -457,15 +398,8 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), - reverse_kl_impl=reverse_kl_impl, - total_valid_tokens=total_valid_tokens, ) - if self._sequence_parallel_logits: - # distillation_loss is local sum, so we need to divide by the number of valid tokens to get the correct scaling - all_reduce(distillation_loss, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group) - distillation_loss /= total_valid_tokens # final global loss - - elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, From 8ee7d5e2d27773081b94d9a2ae5884c525019e71 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 15:53:50 -0400 Subject: [PATCH 05/15] cleanup --- fast_llm/models/gpt/conversion/auto.py | 6 - fast_llm/models/gpt/conversion/config.py | 24 -- fast_llm/models/gpt/conversion/llava.py | 155 ---------- .../models/gpt/conversion/llava_hybrid.py | 40 --- fast_llm/models/gpt/conversion/pixtral.py | 266 ------------------ .../modeling_apriel_hybrid_ssm.py | 66 +---- .../configuration_llava_hybrid.py | 117 -------- .../llava_hybrid/modeling_llava_hybrid.py | 132 --------- 8 files changed, 8 insertions(+), 798 deletions(-) delete mode 100644 fast_llm/models/gpt/conversion/llava.py delete mode 100644 fast_llm/models/gpt/conversion/llava_hybrid.py delete mode 100644 fast_llm/models/gpt/conversion/pixtral.py delete mode 100644 fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py delete mode 100644 fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py index bd2c3d2c7..659d1f12c 100644 --- a/fast_llm/models/gpt/conversion/auto.py +++ 
b/fast_llm/models/gpt/conversion/auto.py @@ -8,12 +8,9 @@ DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, LlamaCheckpointFormat, - LlavaCheckpointFormat, - LlavaHybridCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, MTPLlamaCheckpointFormat, - PixtralCheckpointFormat, Qwen2CheckpointFormat, ) from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler @@ -38,7 +35,4 @@ class AutoGPTHuggingfaceCheckpointHandler( DiffusionDreamCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, - PixtralCheckpointFormat: PixtralHuggingfaceCheckpointHandler, - LlavaCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, - LlavaHybridCheckpointFormat: LlavaHybridCHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py index f6e3d65c9..7c06906ad 100644 --- a/fast_llm/models/gpt/conversion/config.py +++ b/fast_llm/models/gpt/conversion/config.py @@ -47,27 +47,3 @@ class DiffusionLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_hybrid_ssm" - - -class LlavaCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "llava" - # Using default values for vision and text models. Can be overridden in the config - vision_name: typing.ClassVar[str] = "pixtral" - text_name: typing.ClassVar[str] = "mistral" - - -class PixtralCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "pixtral" - - -class LlavaHybridCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "llava_hybrid" - vision_name: typing.ClassVar[str] = "pixtral" - text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" - trust_remote_code: typing.ClassVar[bool] = True - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import LlavaHybridHuggingfaceCheckpointHandler - - return LlavaHybridHuggingfaceCheckpointHandler diff --git a/fast_llm/models/gpt/conversion/llava.py b/fast_llm/models/gpt/conversion/llava.py deleted file mode 100644 index 996269862..000000000 --- a/fast_llm/models/gpt/conversion/llava.py +++ /dev/null @@ -1,155 +0,0 @@ -import typing - -from fast_llm import __version__ -from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value -from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig -from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.models.gpt.conversion.auto import AutoGPTHuggingfaceCheckpointHandler -from tests.utils.model_configs import LlavaGPTHuggingfaceCheckpointFormat - - -class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = 
"LlavaForConditionalGeneration" - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - - @classmethod - def get_vision_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: - return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) - - @classmethod - def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: - return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) - - @classmethod - def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: - vision_handler_cls = cls.get_vision_handler_class() - text_handler_cls = cls.get_text_handler_class() - cfg_dict = cls._load_config(config.path) - kwargs = {} - if "text_config" in cfg_dict: - text_kwargs = text_handler_cls._import_config_dict(cfg_dict["text_config"]) - kwargs.update(text_kwargs) - if "vision_config" in cfg_dict: - vision_kwargs = vision_handler_cls._import_config_dict(cfg_dict["vision_config"]) - vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} - kwargs.update(vision_kwargs) - kwargs.update( - cls._import_config( - {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} - ) - ) - imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) - return CheckpointMetadata( - fast_llm_version=__version__, - model=cls._model_class, - format=config.format, - config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), - shards=["weights"], - ) - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("projector_intermediate_size",),), - ), - ] - - @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: - # handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) - kwargs = {} - for converter in cls._create_config_converters(): - try: - values = () - for export_name in converter.export_names: - try: - value = get_nested_dict_value(config, export_name) - except KeyError: - value = MISSING - values = values + (value,) - values = converter.import_params(values) - for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): - if value is MISSING: - raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") - if fast_llm_name in kwargs: - raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") - kwargs[fast_llm_name] = value - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - return kwargs - - @classmethod - def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: - exported_config = {} - vision_handler_cls = cls.get_vision_handler_class() - text_handler_cls = cls.get_text_handler_class() - for converter in vision_handler_cls._create_config_converters(): - try: - values = converter.export_params( - tuple( - cls._get_fast_llm_attribute(config, 
("vision_encoder",) + fast_llm_name) - for fast_llm_name in converter.fast_llm_names - ) - ) - for export_name, value in zip(converter.export_names, values, strict=True): - if value is not MISSING: - set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - for converter in text_handler_cls._create_config_converters(): - try: - values = converter.export_params( - tuple( - cls._get_fast_llm_attribute(config, fast_llm_name) - for fast_llm_name in converter.fast_llm_names - ) - ) - for export_name, value in zip(converter.export_names, values, strict=True): - if value is not MISSING: - set_nested_dict_value(exported_config, ("text_config",) + export_name, value) - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - for converter in cls._create_config_converters(): - try: - values = converter.export_params( - tuple( - cls._get_fast_llm_attribute(config, fast_llm_name) - for fast_llm_name in converter.fast_llm_names - ) - ) - for export_name, value in zip(converter.export_names, values, strict=True): - if value is not MISSING: - set_nested_dict_value(exported_config, export_name, value) - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - return exported_config - - def _create_weight_converters(self): - vision_handler_cls = self.get_vision_handler_class() - vision_handler = vision_handler_cls(self._model) - converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) - text_handler_cls = self.get_text_handler_class() - text_handler = text_handler_cls(self._model) - converters.extend( - text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) - ) - return converters diff --git a/fast_llm/models/gpt/conversion/llava_hybrid.py b/fast_llm/models/gpt/conversion/llava_hybrid.py deleted file mode 100644 index 45eb1cf2d..000000000 --- a/fast_llm/models/gpt/conversion/llava_hybrid.py +++ /dev/null @@ -1,40 +0,0 @@ -import typing - -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.models.gpt.conversion.config import LlavaHybridCheckpointFormat -from fast_llm.models.gpt.conversion.llava import LlavaHuggingfaceCheckpointHandler - - -class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridCheckpointFormat - architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" - modeling_file = modeling_llava_hybrid.__file__ - configuration_file = configuration_llava_hybrid.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - additional_files = [ - modeling_ssm_hybrid_apriel15b.__file__, - configuration_ssm_hybrid_apriel15b.__file__, - ] - - @classmethod - def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler - - return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler - - @classmethod - def 
_create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", - "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", - "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", - "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", - }, - ), - ] diff --git a/fast_llm/models/gpt/conversion/pixtral.py b/fast_llm/models/gpt/conversion/pixtral.py deleted file mode 100644 index da055a5db..000000000 --- a/fast_llm/models/gpt/conversion/pixtral.py +++ /dev/null @@ -1,266 +0,0 @@ -import typing - -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.layers.attention.rotary.config import Rotary2DConfig -from fast_llm.layers.common.normalization.config import LayerNormalizationConfig -from fast_llm.models.gpt.conversion.llama import KeyValueWeightConverter, MLPLayer2Converter, QueryWeightConverter -from fast_llm.utils import Assert - - -class PixtralNumHeadsConverter(ParamConverter): - """ - Pixtral encoder uses Multi-Head Attention. - Map `num_attention_heads` and `head_groups` to a single `num_heads` parameter. - """ - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 2) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (num_heads, head_groups) = fast_llm_values - assert head_groups == num_heads, "Pixtral encoder expects num_heads == head_groups (MHA)" - return (num_heads,) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (num_heads,) = export_values - return (num_heads, num_heads) - - -class PixtralRotaryParamConverter(ParamConverter): - """ - Pixtral encoder uses 2D Rotary Embeddings. - Map `rope_theta` to a single `rotary` parameter. `rotary_scaling` is not needed. 
- """ - - def __init__(self, fast_llm_names, export_names): - Assert.eq(len(fast_llm_names), 1) - Assert.eq(len(export_names), 1) - self.fast_llm_names = fast_llm_names - self.export_names = export_names - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (rotary_config,) = fast_llm_values - if type(rotary_config) is Rotary2DConfig: - return (rotary_config.theta,) - else: - raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (rotary_theta,) = export_values - rotary_config = { - "type": "rope_2d", - "theta": rotary_theta, - } - return (rotary_config,) - - -class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat - _model_class: typing.ClassVar[FastLLMModelConfig] = FastLLMModelConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), - ConstantImportParamConverter(fast_llm_names=(("patch_norm", "type"),), fast_llm_value="rms_norm"), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="image_encoder"), - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), - ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=( - ( - "transformer", - "num_layers", - ), - ), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=( - ( - "transformer", - "hidden_size", - ), - ), - export_names=(("hidden_size",),), - ), - PixtralNumHeadsConverter( - fast_llm_names=( - ( - "transformer", - "num_attention_heads", - ), - ( - "transformer", - "head_groups", - ), - ), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=( - ( - "transformer", - "ffn_hidden_size", - ), - ), - export_names=(("intermediate_size",),), - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=( - ( - "transformer", - "kv_channels", - ), - ), - export_names=(("head_dim",),), - ), - # ConstantImportParamConverter( - # fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d - # ), - # RenameParamConverter( - # fast_llm_names=( - # ( - # "transformer", - # "rotary", - # "theta", - # ), - # ), - # export_names=(("rope_theta",),), - # ), - PixtralRotaryParamConverter( - fast_llm_names=(("transformer", "rotary"),), - export_names=(("rope_theta",),), - ), - RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), - ] - - def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: 
- return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.feed_forward.down_proj.weight", - self._model.config.base_model, - ), - ] - - def _create_vision_transformer_layer_converters( - self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" - ) -> list[WeightConverter]: - # Vision transformer layer - transformer_config = self._model.config.base_model.vision_encoder.transformer - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - name_bias_cls = [ - # Self-attn - ( - f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", - transformer_config.add_attn_qkv_bias, - QueryWeightConverter, - ), - ( - f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", - ( - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", - ), - transformer_config.add_attn_qkv_bias, - KeyValueWeightConverter, - ), - ( - f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", - transformer_config.add_attn_dense_bias, - WeightConverter, - ), - # Norm - ( - f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", - norm_bias, - WeightConverter, - ), - ( - f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", - norm_bias, - WeightConverter, - ), - ] - converters = [] - for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: - converters += self._get_weight_and_bias_converters( - fast_llm_prefix, - hf_prefix, - use_bias, - cls, - ) - # MLP - converters += self._get_transformer_mlp_converters( - f"layers.{fast_llm_offset + transformer_layer_index}", - f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", - ) - return converters - - def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: - converters = [] - norm_bias = isinstance(self._model.config.base_model.vision_encoder.patch_norm, LayerNormalizationConfig) - converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) - if self._model.config.base_model.vision_encoder.conv_bias: - converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) - converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) - if norm_bias: - converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) - - num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - for i in range(num_layers): - converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) - - converters.extend( - [ - WeightConverter( - f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" - ), - WeightConverter( - f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" - ), - ] - ) - if 
self._model.config.base_model.vision_encoder.adapter_bias: - converters.extend( - [ - WeightConverter( - f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" - ), - WeightConverter( - f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" - ), - ] - ) - - return converters - - @property - def num_layers(self) -> int: - # +2 for projector and conv layers - return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index ad76c56de..5c0a2216c 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, can_return_tuple, logging +from transformers.utils import LossKwargs, logging from transformers.utils.generic import ModelOutput from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -357,13 +357,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx: return 0 - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - # return self.key_cache[layer_idx].shape[-2] + return self.key_cache[layer_idx].shape[-2] def reset(self): self.conv_states.zero_() @@ -892,7 +886,7 @@ def forward( self, hidden_states: torch.Tensor, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - mamba_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, return_mixer_matrix=False, **kwargs, ): @@ -904,10 +898,6 @@ def forward( assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape - # mamba_mask = ( - # None if seqlen == 1 else mamba_mask - # ) # prevent that hidden_states are expanded to mask's seq. dimention., i.e. 
we do not need apply_mask_to_padding_states when generating single token at a time - # hidden_states = apply_mask_to_padding_states(hidden_states, mamba_mask) ssm_state, conv_state = None, None use_precomputed_states = False @@ -988,7 +978,7 @@ def forward( # Update state (B D W) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) + x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( @@ -996,10 +986,7 @@ def forward( weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, - ) # .transpose(1, 2) - # x = apply_mask_to_padding_states(x, mamba_mask).transpose( - # 1, 2 - # ) # zero out everything that comes from padding tokens + ) if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) @@ -1054,14 +1041,14 @@ def step(self, hidden_states, conv_state, ssm_state): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbc = self.in_proj(hidden_states_input) - z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) + zxbcdt = self.in_proj(hidden_states_input) + z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() - dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) # B, d_inner + dt = self.dt_proj(dt) # B, d_inner if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) @@ -1229,42 +1216,6 @@ def __init__(self, config: AprielHybridSSMConfig, **kwargs): # Initialize weights and apply final processing self.post_init() - @can_return_tuple - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - use_cache = use_cache if use_cache is not None else self.config.use_cache - if use_cache and past_key_values is None: - # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) - batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) - output = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **flash_attn_kwargs, - ) - past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return output - class 
KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -1446,7 +1397,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - mamba_mask=attention_mask, # non-expended mask **kwargs, ) diff --git a/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py deleted file mode 100644 index b8e822d9f..000000000 --- a/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py +++ /dev/null @@ -1,117 +0,0 @@ -from transformers import MistralConfig -from transformers.configuration_utils import PretrainedConfig -from transformers.models.auto import CONFIG_MAPPING -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -# Copied from configuration_ssm_hybrid_apriel15b.py -# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict -ssm_config_default = { - # discrete mamba2 - "d_state": 64, - "n_v_heads": 32, - "n_qk_heads": 32, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - "d_conv": 4, - "d_inner": 32 * 128, - # mamba2 - "d_xb": None, # will be set to model dim - "dt_rank": "auto", - "dt_min": 0.001, - "dt_max": 0.1, - "dt_init": "random", - "dt_scale": 1.0, - "dt_init_floor": 1e-4, - "conv_bias": True, -} - - -class AprielSSMHybridConfig(MistralConfig): - model_type = "apriel_ssm_thinker_hybrid" - - def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): - super().__init__(**kwargs) - self.hybrid_block_layout = hybrid_block_layout - self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 - self.ssm_cfg = ssm_cfg or ssm_config_default - - for k, v in ssm_config_default.items(): - if k not in self.ssm_cfg: - self.ssm_cfg[k] = v # to make sure all elements are present in the config - - -class LlavaHybridConfig(PretrainedConfig): - """ - Configuration class for Llava SSM-Hybrid-decoder model. - """ - - model_type = "llava_hybrid" - - def __init__( - self, - vision_config=None, - text_config=None, - image_token_index=32000, - projector_hidden_act="gelu", - projector_intermediate_size=4096, - vision_feature_select_strategy="default", - vision_feature_layer=-2, - image_seq_length=576, - multimodal_projector_bias=True, - **kwargs, - ): - self.image_token_index = image_token_index - self.projector_hidden_act = projector_hidden_act - # projector_intermediate_size is an addition to the original Llava config - self.projector_intermediate_size = projector_intermediate_size - self.image_seq_length = image_seq_length - - if vision_feature_select_strategy not in ["default", "full"]: - raise ValueError( - "vision_feature_select_strategy should be one of 'default', 'full'." 
- f"Got: {vision_feature_select_strategy}" - ) - - self.vision_feature_select_strategy = vision_feature_select_strategy - self.vision_feature_layer = vision_feature_layer - - if isinstance(vision_config, dict): - vision_config["model_type"] = ( - vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" - ) - vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) - elif vision_config is None: - vision_config = CONFIG_MAPPING["clip_vision_model"]( - intermediate_size=4096, - hidden_size=1024, - patch_size=14, - image_size=336, - num_hidden_layers=24, - num_attention_heads=16, - vocab_size=32000, - projection_dim=768, - ) - - self.vision_config = vision_config - - if isinstance(text_config, dict): - # Load the custom SSM hybrid config if specified - if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": - text_config = AprielSSMHybridConfig(**text_config) - else: - text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" - text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) - elif text_config is None: - text_config = CONFIG_MAPPING["llama"]() - - self.text_config = text_config - self.multimodal_projector_bias = multimodal_projector_bias - - super().__init__(**kwargs) - - -__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py deleted file mode 100644 index 68073f9cd..000000000 --- a/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py +++ /dev/null @@ -1,132 +0,0 @@ -from torch import nn -from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel -from transformers.activations import ACT2FN - -from .configuration_llava_hybrid import LlavaHybridConfig - -try: - # In the fast-llm repo, import from the SSM modeling file - from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( - AprielThinkerSSMHybridModel, - HybridMambaAttentionDynamicCache, - ) -except ImportError: - # In the exported checkpoint, import from local file - from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache - - -class LlavaMultiModalProjector(nn.Module): - def __init__(self, config: LlavaHybridConfig): - super().__init__() - # We have hidden_size * the number of vision feature layers - num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) - self.linear_1 = nn.Linear( - config.vision_config.hidden_size * num_feature_layers, - config.projector_intermediate_size, - bias=config.multimodal_projector_bias, - ) - self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear( - config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias - ) - - def forward(self, image_features): - hidden_states = self.linear_1(image_features) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - -class LlavaHybridModel(LlavaModel): - """ - Llava SSM-Hybrid-decoder model. 
- """ - - config_class = LlavaHybridConfig - - def __init__(self, config: LlavaHybridConfig): - super(LlavaModel, self).__init__(config) - self.vision_tower = AutoModel.from_config(config.vision_config) - - self.multi_modal_projector = LlavaMultiModalProjector(config) - assert ( - config.text_config.model_type == "apriel_ssm_thinker_hybrid" - ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" - - self.language_model = AprielThinkerSSMHybridModel(config.text_config) - self.post_init() - - -class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): - config_class = LlavaHybridConfig - - def __init__(self, config: LlavaHybridConfig): - super(LlavaForConditionalGeneration, self).__init__(config) - self.model = LlavaHybridModel(config) - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - self.post_init() - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - output_router_logits=False, - cache_position=None, - position_ids=None, - use_cache=True, - pixel_values=None, - **kwargs, - ): - # Copy of the method from `AprielThinkerSSMHybridForCausalLM` - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
- # (we can't check exception 3 while compiling) - if not empty_past_kv: - if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - else: - past_key_values = HybridMambaAttentionDynamicCache( - self.config.text_config, input_ids.shape[0], self.dtype, device=self.device - ) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if not empty_past_kv: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases - - # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation` - if cache_position[0] == 0: - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model - model_inputs["pixel_values"] = pixel_values - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "output_router_logits": output_router_logits, - # "logits_to_keep": self.config.num_logits_to_keep, - "cache_position": cache_position, - } - ) - return model_inputs From 43ca9133d6c6eecaa37a043f67a997d80b4e2a9e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 15:57:44 -0400 Subject: [PATCH 06/15] cleanup --- fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/engine/schedule/config.py | 6 ------ fast_llm/engine/training/config.py | 2 +- fast_llm/functional/triton/mlp.py | 6 ++++-- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 132bfd383..7829c243b 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -153,7 +153,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad if input_.grad is not None else torch.zeros_like(input_) + return input_.grad def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index a5e0a86a6..272b7c6ae 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -48,12 +48,6 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) - # Image inputs - max_image_size: int | None = Field( - default=None, - desc="Maximum image height and width", - hint=FieldHint.optional, - ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 809d46801..531bc206b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -361,7 +361,7 @@ def _validate(self) -> 
None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - # Assert.eq(self.model.distributed.tensor_parallel, 1) + Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index f3d9d7d0c..ab408368f 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -47,7 +47,8 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": + # Triton doesn't like enums, so we use str instead of ActivationType. + if activation_type == "gelu": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -97,7 +98,8 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": + # Triton doesn't like enums, so we use str instead of ActivationType. + if activation_type == "gelu": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) From 15405a10d195752a9c7d6c983c5b764d78196b72 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:11:30 -0400 Subject: [PATCH 07/15] cleanup --- setup.cfg | 6 +-- tests/data/test_sampling.py | 18 +++------ tests/layers/test_lm_head.py | 2 + tests/utils/model_configs.py | 72 ------------------------------------ 4 files changed, 11 insertions(+), 87 deletions(-) diff --git a/setup.cfg b/setup.cfg index f65f21a85..2a1614554 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,11 +52,11 @@ HUGGINGFACE = # To install on cpu environment (ex. 
for IDE support):
 # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation
 SSM =
-    mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba
+    mamba_ssm[causal-conv1d]==2.2.4
     cartesia_pytorch>=0.0.2
 
-# GENERATION =
-#     lm_eval>=0.4.9
+GENERATION =
+    lm_eval>=0.4.9
 
 
 # Required for supporting vision inputs
diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py
index 2f9b143b9..6a2be3dcc 100644
--- a/tests/data/test_sampling.py
+++ b/tests/data/test_sampling.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters, ShufflingType
+from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType
 from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
 from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.utils import Assert
@@ -78,17 +78,11 @@ def get(self, index: int, offset=0, length=None, use_loss_masking_spans: bool =
     def __len__(self) -> int:
         return len(self._samples)
 
-    def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray:
-        doc_sizes = []
-        im_sizes = []
-        for index in range(len(self)):
-            doc_size, im_size = self.get_document_size(index)
-            doc_sizes.append(doc_size)
-            im_sizes.append(im_size)
-        return np.array(doc_sizes, dtype=np.int64), np.array(im_sizes, dtype=np.int64)
-
-    def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int:
-        return len(self._samples[index]), []
+    def get_document_sizes(self) -> np.ndarray:
+        return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64)
+
+    def get_document_size(self, index: int) -> int:
+        return len(self._samples[index])
 
     def name(self) -> str:
         return "dataset"
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index c836df9fe..f14f028e1 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -23,10 +23,12 @@ def _reverse_kl_loss(
 ):
     scaled_target = target / teacher_softmax_temperature
 
+    scaled_target = torch.clamp(scaled_target, min=-50, max=50)
     teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
 
     with torch.enable_grad():
         # Use log_softmax for consistency instead of _fused_softmax
+        logits = torch.clamp(logits, min=-50, max=50)
         student_log_probs = torch.log_softmax(logits, dim=-1)
         if loss_mask is None:
             loss = torch.nn.functional.kl_div(
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index a818a1f23..aa8100126 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -16,7 +16,6 @@
     DiffusionDreamCheckpointFormat,
     DiffusionLlamaCheckpointFormat,
     LlamaCheckpointFormat,
-    LlavaCheckpointFormat,
     MistralCheckpointFormat,
     MixtralCheckpointFormat,
     MTPLlamaCheckpointFormat,
@@ -681,77 +680,6 @@ def _update_and_add_testing_config(
     skip_tests=("sdp", "ms"),
 )
 
-_update_and_add_testing_config(
-    # Tests hybrid Mamba, llamba converter.
- "llama", - "llava", - extra_args=[ - "batch.max_image_size=128", - "model.base_model.vision_encoder.type=pixtral", - "model.base_model.vision_encoder.patch_norm.type=rms_norm", - "model.base_model.vision_encoder.transformer.add_linear_biases=False", - "model.base_model.vision_encoder.transformer.causal=False", - "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", - "model.base_model.vision_encoder.transformer.type=image_encoder", - "model.base_model.vision_encoder.transformer.gated=True", - "model.base_model.vision_encoder.transformer.num_layers=2", - "model.base_model.vision_encoder.transformer.hidden_size=256", - "model.base_model.vision_encoder.transformer.num_attention_heads=8", - "model.base_model.vision_encoder.transformer.head_groups=8", - "model.base_model.vision_encoder.transformer.init_method_std=0.022", - "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", - "model.base_model.vision_encoder.adapter_size=256", - "model.distributed.training_dtype=torch.bfloat16", - ], - megatron_args=None, - checkpoint_format=LlavaCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, - }, - compare_factor=8.0, -) - -_update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "hybrid_mamba2", - "vision_hybrid_mamba2", - model_type="hybrid_ssm", - extra_args=[ - "batch.max_image_size=128", - "model.base_model.vision_encoder.type=pixtral", - "model.base_model.vision_encoder.patch_norm.type=rms_norm", - "model.base_model.vision_encoder.transformer.add_linear_biases=False", - "model.base_model.vision_encoder.transformer.causal=False", - "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", - "model.base_model.vision_encoder.transformer.type=image_encoder", - "model.base_model.vision_encoder.transformer.gated=True", - "model.base_model.vision_encoder.transformer.num_layers=2", - "model.base_model.vision_encoder.transformer.hidden_size=256", - "model.base_model.vision_encoder.transformer.num_attention_heads=8", - "model.base_model.vision_encoder.transformer.head_groups=8", - "model.base_model.vision_encoder.transformer.init_method_std=0.022", - "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", - "model.base_model.vision_encoder.adapter_size=512", - "model.distributed.training_dtype=torch.bfloat16", - ], - megatron_args=None, - checkpoint_format=LlavaHybridHuggingfaceCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, - }, - compare_factor=16.0, -) - @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: From a3dc89d979e41ede09bee5aad8d9d0f8fa8d8445 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:18:33 -0400 Subject: [PATCH 08/15] stuff --- fast_llm/engine/config_utils/data_type.py | 2 +- 
fast_llm/engine/config_utils/run.py | 2 +- fast_llm/layers/attention/attention.py | 8 +++-- fast_llm/layers/block/block.py | 13 ------- fast_llm/layers/block/config.py | 2 -- fast_llm/layers/decoder/block.py | 36 ++++++++++++++++--- fast_llm/layers/decoder/config.py | 23 ++++++++++-- .../layers/decoder/mlp/mixture_of_experts.py | 6 ++-- fast_llm/layers/decoder/mlp/mlp.py | 4 ++- fast_llm/layers/language_model/embedding.py | 4 --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++- fast_llm/layers/ssm/mamba.py | 9 ++--- fast_llm/layers/ssm/mamba2.py | 9 ++--- 13 files changed, 76 insertions(+), 46 deletions(-) diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 0929b7cb1..f4a2cfd6c 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -9,7 +9,7 @@ from triton import language as tl -class DataType(str, enum.Enum): +class DataType(enum.StrEnum): """ An enum to represent data types independently of third party libraries, so we can swap them more easily and allow for lazy imports. diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1737f4308..1849a2316 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -136,7 +136,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9a940f4cb..2d4f049f0 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -64,6 +64,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -71,6 +72,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) @@ -273,7 +275,7 @@ def _query_key_value_backward( input_grad.add_(self.key_value.backward(key_value_grad, context.pop("key_value"))) return input_grad - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], @@ -340,7 +342,7 @@ def forward( max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.dropout if self.training else 0.0, window_size=window_size, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -350,7 +352,7 @@ def forward( value, window_size=window_size, dropout_p=self._config.dropout if self.training else 0.0, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 773cce87e..0f975c9c5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -123,16 +123,3 @@ class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): """ Base class for actual blocks, i.e., base blocks that are also `Layers`. 
""" - - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - return_input: bool = False, - ): - super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) - self._return_input = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index df5bd8181..47c1ab9b7 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -87,7 +87,6 @@ def get_block( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ) -> "Block": return self.layer_class( self, @@ -95,7 +94,6 @@ def get_block( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, - return_input=return_input, ) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index ba4c370c2..ce9893ea4 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.block import BaseBlock, Block +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -19,18 +19,44 @@ logger = logging.getLogger(__name__) -class BlockWithBias[ConfigType: Config](BaseBlock[ConfigType]): +class BlockWithBias[ConfigType: Config](Block[ConfigType]): """ Base class for mixer and MLP modules. """ - @abc.abstractmethod + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self._return_bias = return_bias + def forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor: + output, bias = self._forward(input_, kwargs, losses, metrics) + if self._return_bias: + return output, bias + else: + return output if bias is None else output + bias + + @abc.abstractmethod + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: pass @@ -58,7 +84,7 @@ def __init__( peft=peft, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
- self._return_input: bool = return_input + self._return_input = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) @@ -70,6 +96,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) self.mlp = self._config.mlp.get_layer( @@ -77,6 +104,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) def setup(self, distributed: Distributed) -> None: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 5f8131b5c..724b8d172 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BaseBlockConfig, BlockConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -15,7 +15,7 @@ @config_class() -class BlockWithBiasConfig(BaseBlockConfig): +class BlockWithBiasConfig(BlockConfig): """ A common interface for various blocks and block layers. """ @@ -30,6 +30,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = False, ) -> "BlockWithBias": return self.layer_class( self, @@ -37,6 +38,7 @@ def get_layer( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + return_bias=return_bias, ) @@ -94,6 +96,23 @@ def layer_class(self) -> "type[DecoderBlock]": return DecoderBlock + def get_block( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_input: bool = False, + ) -> "DecoderBlock": + return self.layer_class( + self, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + return_input=return_input, + ) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 089fa2dc7..d4cb46dbf 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -46,6 +46,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): Assert.gt(config.experts, 1) # TODO: Implement? 
@@ -56,6 +57,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self.router = self._config.router.get_layer( self._hidden_dim, @@ -83,9 +85,9 @@ def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 9dd17d698..aaea94adb 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -28,6 +28,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -35,6 +36,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() @@ -102,7 +104,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _config: MLPConfig - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d1e13a5b..362ffaa22 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -36,17 +36,13 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f014012b2..c9fc609b0 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -43,6 +43,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -50,6 +51,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) @@ -128,7 +130,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e77a4468b..081aabe65 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -43,13 +43,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not 
supported for Mamba" @@ -120,7 +117,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b0657313d..4b0bd4366 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -41,13 +41,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) num_heads = div(self._config.d_inner, self._config.state_size) @@ -153,7 +150,7 @@ def __init__( BlockDimNames.sequence_q, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], From 414f87edb881247b8351a1d09e41fe883de93247 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:19:07 -0400 Subject: [PATCH 09/15] stuff --- fast_llm/layers/language_model/head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index ade1144d2..4b080b360 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -52,17 +52,13 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) From 4a21360ed07277ba90e5bdbefe6cf3a43a83a9fd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:25:34 -0400 Subject: [PATCH 10/15] stuff --- fast_llm/layers/language_model/config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index f59b4cffd..1af6bdc38 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -395,7 +395,12 @@ def get_blocks(self, distributed_config: DistributedConfig): peft=self.peft, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1, + # TODO: Not all blocks support this argument. 
+ **( + {"return_input": True} + if self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1 + else {} + ), ) for i in range(len(self.decoder)) ], From bb7c62d086ac0cf405ff83a358c78f03e77e4f17 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 19:55:44 -0400 Subject: [PATCH 11/15] Embeddings --- fast_llm/core/ops.py | 8 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/attention/rotary/rotary.py | 2 +- fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/embedding.py | 60 ++++-- fast_llm/layers/multi_modal/embedding.py | 183 ------------------ fast_llm/layers/vision/__init__.py | 0 .../{vision_encoder => vision}/config.py | 0 .../patch_convolution.py | 2 +- .../preprocessing.py | 3 +- 10 files changed, 58 insertions(+), 206 deletions(-) delete mode 100644 fast_llm/layers/multi_modal/embedding.py create mode 100644 fast_llm/layers/vision/__init__.py rename fast_llm/layers/{vision_encoder => vision}/config.py (100%) rename fast_llm/layers/{vision_encoder => vision}/patch_convolution.py (96%) rename fast_llm/layers/{vision_encoder => vision}/preprocessing.py (98%) diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index a7492daa5..bb61aadd0 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -26,7 +26,7 @@ def reduce_op( return (input_, handle) if async_op else input_ -def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: +def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: """Split the tensor along its last dimension and keep the corresponding slice.""" if group: @@ -139,11 +139,11 @@ class _Split(torch.autograd.Function): """Split the input and keep only the corresponding chuck to the rank.""" @staticmethod - def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa + def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa return split_op(input_, group, dim) @staticmethod - def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa + def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa ctx.group = group ctx.dim = dim return split_op(input_, group, dim) @@ -209,7 +209,7 @@ def reduce_backward(input_: torch.Tensor, group: ProcessGroup | None) -> torch.T @torch._dynamo.disable # noqa -def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: +def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: return _Split.apply(input_, group, dim) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 388ef9d41..60b20e4d3 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -14,7 +14,7 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims +from fast_llm.layers.vision.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div try: diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 6a85f43e5..189596352 100644 --- a/fast_llm/layers/attention/rotary/rotary.py 
+++ b/fast_llm/layers/attention/rotary/rotary.py @@ -17,7 +17,7 @@ RotaryConfig, YarnRotaryConfig, ) -from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 216db6de3..e852cf96f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -9,7 +9,7 @@ from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig +from fast_llm.layers.vision.config import VisionEncoderConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -32,7 +32,9 @@ def multi_token_prediction_loss(index: int) -> str: class LanguageModelKwargs(BlockKwargs): + token_ids = "token_ids" position_ids = "position_ids" + embedding_map = "embedding_map" # TODO: These are generic labels = "labels" tokens = "tokens" diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 362ffaa22..8625945ce 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import reduce_forward, split +from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim @@ -73,34 +73,62 @@ def __init__( ) @torch.compile - def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: + def _forward( + self, + input_: torch.Tensor, + token_ids: torch.Tensor, + position_ids: torch.Tensor | None, + mask_inputs: bool, + # TODO: Flatten the batch and sequence in the map? + embedding_map: tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None, + ) -> torch.Tensor: Assert.eq(position_ids is None, self.position_embeddings_weight is None) group = self._parallel_dim.group if self._vocab_parallel: - input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) - masked_input = (input_ - self._vocab_start_index) * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa + token_mask = (token_ids >= self._vocab_start_index) * (token_ids < self._vocab_end_index) + masked_input = (token_ids - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) + # TODO: Input masking of position embeddings inconsistant with non-vocab-parallel if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + + if embedding_map is not None: + # TODO: Accumulate redundant with masking? 
+ input_index, embedding_index = embedding_map + if self._sequence_parallel: + input_ = gather(input_, group=group, dim=0) + embeddings = embeddings.index_put(embedding_index, input_[input_index], accumulate=True) + if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: - input_ = split(input_, group=group, dim=0) + token_ids = split(token_ids, group=group, dim=0) if self.position_embeddings_weight is not None: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: - input_mask = input_ >= 0 - masked_input = input_ * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) - else: - embeddings = torch.embedding(self.word_embeddings_weight, input_) + token_mask = token_ids >= 0 + token_ids = token_ids * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, token_ids) if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: - embeddings = embeddings * input_mask.unsqueeze(2) + embeddings = embeddings * token_mask.unsqueeze(2) + + if embedding_map is not None: + # TODO: Accumulate redundant with masking? + input_index, embedding_index = embedding_map + if self._sequence_parallel: + # TODO:: Filter and shift embedding map instead? (needs cuda sync) + input_ = gather(input_, group=group, dim=0) + embeddings_ = embeddings.new_zeros(embeddings.shape[0] * group.size(), *embeddings.shape[1:]) + embeddings_.index_put(embedding_index, input_[input_index], accumulate=True) + embeddings = embeddings + split(embeddings_, group=group, dim=0) + else: + embeddings = embeddings.index_put(embedding_index, input_[input_index], accumulate=True) + with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): @@ -120,8 +148,14 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) + return self._forward( - input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs) + input_, + kwargs.get(LanguageModelKwargs.token_ids), + kwargs.get(LanguageModelKwargs.position_ids), + # TODO ====== Vision ====== Review input masking. 
+ kwargs.get(LanguageModelKwargs.mask_inputs), + kwargs.get(LanguageModelKwargs.embedding_map), ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py deleted file mode 100644 index a5a789f9e..000000000 --- a/fast_llm/layers/multi_modal/embedding.py +++ /dev/null @@ -1,183 +0,0 @@ -import typing - -import torch - -from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import reduce_forward, split -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert, div - - -class MultiModalEmbedding(LanguageModelEmbedding): - """ - Multi-modal embedding layer to combine embeddings from text, image and more modalities. - """ - - def __init__( - self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, - ): - super().__init__(config, tensor_space) - - # @torch.compile - def _forward( - self, - input_: torch.Tensor, - tokens: torch.Tensor, - position_ids: torch.Tensor | None, - image_positions: list[torch.Tensor] | None, - image_sizes: list[list[tuple[int, int]]] | None, - ) -> torch.Tensor: - """ - Forward pass for the multi-modal embedding layer. - Args: - input_: The input tensor (image embeddings). - tokens: The tokenized text input. - position_ids: The position ids for the text input. - image_positions: The positions of the image tokens in the input. - image_sizes: The sizes of the images in the input. - Returns: - The combined embeddings for text and images. 
- """ - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) - group = self._tensor_space.distributed.tensor_group - if self._sequence_parallel: - micro_seqlen = input_.size(0) - patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen - patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen - else: - patch_start_offset = 0 - patch_end_offset = input_.size(0) - if self._parallel_embeddings: - token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) - masked_tokens = (tokens - self._vocab_start_index) * token_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa - # Cloning since we will modify the embeddings in-place - embeddings = embeddings.clone() - # the embeddings tensor are full-sized, but we might get a split of the patch embeddings - # We need to determine the offset in the embeddings tensor for each sample - # and also account for the special image tokens if applicable - for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): - image_embedding_offset = 0 - for position, size in zip(positions, sizes): - num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) - if image_embedding_offset + num_patches < patch_start_offset: - image_embedding_offset += num_patches - continue - if self._config.vision_encoder.image_break_token is not None: - patch_height = div(size[0], self._config.vision_encoder.patch_size) - patch_width = div(size[1], self._config.vision_encoder.patch_size) - for row in range(patch_height): - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) - if row_start_src > patch_end_offset: - break - if row_start_src + patch_width <= patch_start_offset: - continue - - input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset - input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset - embeddings_start_index = row_start_dst + max(patch_start_offset - row_start_src, 0) - embeddings_end_index = ( - row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) - ) - # row_end_src = min(row_start_src + patch_width, patch_end_offset) - if self._sequence_parallel: - embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ - input_start_index:input_end_index, sample_idx - ] - else: - embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ - sample_idx, input_start_index:input_end_index - ] - else: - input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset - input_end_index = ( - min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset - ) - embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) - embedding_end_index = ( - position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) - ) - embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ - input_start_index:input_end_index, sample_idx - ] - # embeddings[sample_idx, position : position + num_patches] = input_[ - # sample_idx, image_embedding_offset : image_embedding_offset + num_patches - # ] - image_embedding_offset += num_patches - if image_embedding_offset > patch_end_offset: - break - embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: - embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) - if self._sequence_parallel: - embeddings = split(embeddings, group=group, dim=0) - else: - if self._sequence_parallel: - tokens = split(tokens, group=group, dim=0) - if self._use_absolute_position_embeddings: - position_ids = split(position_ids, group=group, dim=0) - # mask padded tokens - token_mask = tokens >= 0 - masked_tokens = tokens * token_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) - embeddings = embeddings.clone() - for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): - image_embedding_offset = 0 - for position, size in zip(positions, sizes): - num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) - if self._config.vision_encoder.image_break_token is not None: - patch_height = div(size[0], self._config.vision_encoder.patch_size) - patch_width = div(size[1], self._config.vision_encoder.patch_size) - - for row in range(patch_height): - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) - - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width - ] - else: - embeddings[sample_idx, position : position + num_patches] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_patches - ] - # Move to the next image in the input tensor - image_embedding_offset += num_patches - - if self._use_absolute_position_embeddings: - embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) - with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator - ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - return embeddings.to(dtype=self._residual_dtype) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Embedding output", - dtype=self._residual_dtype, - ) - position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) - tokens = kwargs.get(LanguageModelKwargs.tokens) - - return self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/vision/__init__.py b/fast_llm/layers/vision/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision/config.py similarity index 100% rename from fast_llm/layers/vision_encoder/config.py rename to fast_llm/layers/vision/config.py diff --git a/fast_llm/layers/vision_encoder/patch_convolution.py b/fast_llm/layers/vision/patch_convolution.py similarity index 96% rename from fast_llm/layers/vision_encoder/patch_convolution.py rename to fast_llm/layers/vision/patch_convolution.py index 0a41efe2c..b40a9593e 100644 --- a/fast_llm/layers/vision_encoder/patch_convolution.py +++ b/fast_llm/layers/vision/patch_convolution.py @@ -8,7 +8,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.block import Block from 
fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.vision_encoder.config import PatchConvolutionConfig, VisionEncoderKwargs +from fast_llm.layers.vision.config import PatchConvolutionConfig, VisionEncoderKwargs from fast_llm.tensor import TensorMeta diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision/preprocessing.py similarity index 98% rename from fast_llm/layers/vision_encoder/preprocessing.py rename to fast_llm/layers/vision/preprocessing.py index 65eac6c45..83246608e 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision/preprocessing.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.vision_encoder.config import ImageNormalizationConfig, VisionEncoderConfig, VisionEncoderKwargs +from fast_llm.layers.vision.config import ImageNormalizationConfig, VisionEncoderConfig, VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -130,7 +130,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] sequence_lengths = [0] max_sequence_length = -1 - kwargs.get(AttentionKwargs.sequence_first) for sample_index, (sample_images_, positions) in enumerate( zip(kwargs[VisionEncoderKwargs.images], kwargs.get(VisionEncoderKwargs.image_positions), strict=True) From 47b9a448fa74aa71a6538f7208ca77ffe2972f7c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 30 Sep 2025 22:39:09 -0400 Subject: [PATCH 12/15] Model interface --- fast_llm/engine/base_model/base_model.py | 108 +++--- fast_llm/engine/base_model/config.py | 34 +- fast_llm/engine/evaluation/evaluator.py | 2 +- fast_llm/engine/multi_stage/config.py | 5 - fast_llm/engine/multi_stage/multi_stage.py | 15 +- fast_llm/engine/multi_stage/stage.py | 16 +- fast_llm/engine/multi_stage/stage_base.py | 13 +- fast_llm/engine/schedule/runner.py | 2 +- fast_llm/engine/training/trainer.py | 12 +- fast_llm/layers/attention/attention.py | 117 +++++++ fast_llm/layers/attention/config.py | 18 - fast_llm/layers/attention/preprocessing.py | 153 --------- fast_llm/layers/attention/rotary/config.py | 6 +- fast_llm/layers/attention/rotary/rotary.py | 94 ++---- fast_llm/layers/block/block.py | 21 +- fast_llm/layers/block/config.py | 88 ++--- fast_llm/layers/block/sequence.py | 108 ++++++ .../layers/common/normalization/config.py | 4 +- fast_llm/layers/decoder/block.py | 13 +- fast_llm/layers/decoder/config.py | 9 +- fast_llm/layers/decoder/mlp/config.py | 21 -- .../layers/decoder/mlp/mixture_of_experts.py | 22 +- fast_llm/layers/language_model/config.py | 318 +++++++----------- fast_llm/layers/language_model/embedding.py | 32 ++ fast_llm/layers/language_model/head.py | 112 ++++-- .../language_model/multi_token_prediction.py | 77 +++++ .../layers/language_model/preprocessing.py | 107 ------ fast_llm/layers/vision/config.py | 37 +- fast_llm/layers/vision/patch_convolution.py | 26 +- fast_llm/layers/vision/preprocessing.py | 47 +-- fast_llm/models/gpt/config.py | 19 +- fast_llm/models/gpt/model.py | 145 +++++--- tests/layers/test_lm_head.py | 1 + tests/utils/utils.py | 2 +- 34 files changed, 866 insertions(+), 938 deletions(-) delete mode 100644 fast_llm/layers/attention/preprocessing.py create mode 100644 fast_llm/layers/language_model/multi_token_prediction.py delete mode 100644 
fast_llm/layers/language_model/preprocessing.py diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 0a3f8d1ce..ce7002c54 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -1,23 +1,19 @@ import abc import typing -import torch import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta -from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.inference.runner import InferenceRunner -class Module(torch.nn.Module, abc.ABC): - """ """ - +class LayerBase(torch.nn.Module, abc.ABC): _is_setup: bool = False _distributed: Distributed @@ -27,57 +23,87 @@ def __init__(self, distributed_config: DistributedConfig): def setup(self, distributed: Distributed) -> None: assert not self._is_setup + for layer in self.get_layers(): + if layer is not self: + layer.setup(distributed) distributed.check_config(self._distributed_config) self._distributed = distributed self._is_setup = True + @abc.abstractmethod + def get_layers(self) -> list["Layer"]: + """ + The list of layers as meant to be seen by the Fast-LLM engine. + May differ from the module configuration seen by pytorch. + """ -class Layer(Module): - # Weight used to determine the stage size + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + out = 0 + for layer in self.get_layers(): + if layer is self: + raise NotImplementedError() + out += layer.get_compute_usage(input_, kwargs, config) + return out + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + losses = [] + for layer in self.get_layers(): + if layer is not self: + losses += layer.get_loss_definitions(count) + return losses + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + for layer in self.get_layers(): + if layer is not self: + layer.preprocess(batch, kwargs) + + +class Layer(LayerBase): + # Weight used to determine the stage size. layer_count: float = 1.0 + def get_layers(self) -> list["Layer"]: + # Return a breakdown of the layer into atomic ones, + # i.e. the list of layers from as seen from the Fast-LLM model. + return [self] + @abc.abstractmethod def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: pass - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() +class LayerWithNamespace(Layer): + """ + A layer with its own namespace for preprocessing (kwargs), + so that it doesn't inadvertently interact with other layers. + TODO: Consider namespace for losses and metrics? 
+ """ -class Sequential(Layer): - def __init__(self, distributed_config: DistributedConfig): - super().__init__(distributed_config) - self.layers = torch.nn.ModuleList(self.get_layers()) - - def __getitem__(self, item): - return self.layers[item] + def __init__(self, layer: Layer, namespace: str): + super().__init__(layer._distributed_config) + self._layer = layer + self._namespace = namespace + self.layer_count = self._layer.layer_count + self.get_compute_usage = self._layer.get_compute_usage - def __iter__(self): - return iter(self.layers) - - def __len__(self): - return len(self.layers) + def setup(self, distributed: Distributed) -> None: + self._layer.setup(distributed) + super().setup(distributed) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: - for layer in self.layers: - input_ = layer(input_, kwargs, losses, metrics) - return input_ + return self._layer.forward(input_, kwargs[self._namespace], losses, metrics) - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + assert self._namespace not in kwargs + kwargs[self._namespace] = kwargs.copy() + return self._layer.preprocess(batch, kwargs[self._namespace]) - def setup(self, distributed: Distributed) -> None: - super().setup(distributed) - for layer in self.layers: - layer.setup(distributed) - -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase): def __init__( self, @@ -85,23 +111,14 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - for key, value in self.named_modules(): - value.module_name = key - for key, value in self.named_parameters(): - Assert.custom(isinstance, value, ParameterMeta) - # Rename to the parameter full name - value.tensor_name = key # Reference models # TODO: Add basic handling (preprocessor) in this class. self._reference_models: dict[str, "InferenceRunner"] = {} - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: + # TODO ====== Remove (Move batch splitting elsewhere) ====== pass @abc.abstractmethod @@ -114,9 +131,12 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: + # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== pass def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: + # TODO ====== Tied weights ====== + # Return tuples of independently defined metas to tie together. # For each tied weight, return the weight and the tuple of layers sharing it. # The weight should be defined in the first layer in the set. # Warning: This may return buffers instead of metas after stage setup. 
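For reference, the namespace mechanism introduced by `LayerWithNamespace` above is easiest to see in a minimal standalone sketch. These are toy stand-in classes (not the fast_llm ones, and no distributed setup); they only mirror the semantics of `preprocess`/`forward` shown in the hunk: each wrapped layer preprocesses into its own copy of the kwargs, so two blocks writing the same key no longer collide.

import typing

import torch


class ToyLayer:
    """Stand-in for a `Layer` whose preprocessing writes a shared kwargs key."""

    def __init__(self, mask_value: float):
        self._mask_value = mask_value

    def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
        # Both instances write the same key; without namespaces they would clash.
        kwargs["attention_mask_value"] = self._mask_value

    def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> torch.Tensor:
        return input_ + kwargs["attention_mask_value"]


class ToyLayerWithNamespace:
    """Mirrors `LayerWithNamespace`: preprocessing goes into `kwargs[namespace]`."""

    def __init__(self, layer: ToyLayer, namespace: str):
        self._layer, self._namespace = layer, namespace

    def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
        assert self._namespace not in kwargs
        kwargs[self._namespace] = kwargs.copy()
        self._layer.preprocess(batch, kwargs[self._namespace])

    def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> torch.Tensor:
        return self._layer.forward(input_, kwargs[self._namespace])


batch = torch.zeros(2, 3)
kwargs: dict[str, typing.Any] = {"sequence_length": 3}
layers = [ToyLayerWithNamespace(ToyLayer(1.0), "block_0"), ToyLayerWithNamespace(ToyLayer(-1.0), "block_1")]
for layer in layers:
    layer.preprocess(batch, kwargs)
# Each wrapped layer sees only its own copy of the kwargs.
assert kwargs["block_0"]["attention_mask_value"] == 1.0
assert kwargs["block_1"]["attention_mask_value"] == -1.0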
diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 78fafea34..f1eef47b9 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -4,14 +4,15 @@ from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import compare_nested, log +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.utils import Assert, compare_nested, log if typing.TYPE_CHECKING: - import torch + from fast_llm.engine.base_model.base_model import BaseModel @config_class() -class BaseModelConfig(Config): +class ModuleConfig(Config): """ Abstract config class for a base model. # TODO: Find better name? @@ -43,7 +44,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: - if isinstance(value, BaseModelConfig): + if isinstance(value, ModuleConfig): # TODO: Make sure all nested configs have an architecture type hint? return value._get_architecture() elif isinstance(value, Config): @@ -57,12 +58,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: return self._serialize_value(value) -class Preprocessor(abc.ABC): - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass +@config_class() +class BaseModelConfig(ModuleConfig): + """ + Abstract config class for a base model. + """ + + def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel": + from fast_llm.tensor import ParameterMeta + + model = self.base_model_class(self, distributed_config) + # Storing the global name of each module and tensor. + # Done here because it needs to run right after `model.__init__()` + for key, value in model.named_modules(): + value.module_name = key + for key, value in model.named_parameters(): + Assert.custom(isinstance, value, ParameterMeta) + # Rename to the parameter full name + value.tensor_name = key + return model + @property @abc.abstractmethod - def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + def base_model_class(self) -> type["BaseModel"]: pass diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index d5202a90f..e055595bd 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -116,7 +116,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() + self._loss_defs = self._multi_stage.base_model.get_loss_definitions() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index aa18f5052..27c0e2b7b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -137,11 +137,6 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. 
High overhead", hint=FieldHint.logging, ) - compile_all: bool = Field( - default=False, - desc="Compile the whole model using torch.compile.", - hint=FieldHint.expert, - ) @config_class() diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index e48fdb88b..77dc4e7dd 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -25,7 +25,6 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): - base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False _flat_shard: torch.Tensor _shards: dict[str, torch.Tensor] @@ -46,7 +45,8 @@ def __init__( stage_filter: set | None = None, ): super().__init__(config) - self._base_model = self.base_model_class(self._config.base_model, self._config.distributed) + self._base_model = self._config.base_model.get_base_model(self._config.distributed) + self._layers = self._base_model.get_layers() self._training = None self._verbose = verbose self._stage_filter = stage_filter @@ -67,10 +67,8 @@ def __init__( self._stages = [ Stage( config=self._config.multi_stage, - base_model=self._base_model, + layers=self._layers[stage_splits[i] : stage_splits[i + 1]], distributed_config=self._config.distributed, - begin=stage_splits[i], - end=stage_splits[i + 1], index=i, ) for i in (range(self._num_stages)) @@ -510,12 +508,9 @@ def _split_into_stages(self) -> list[int]: # Create stages (greedy split, could do better). stage_splits = [0] layer_counter, last_counter = 0, 0 - for i, layer in enumerate(self._base_model): + for i, layer in enumerate(self._layers): layer_counter += layer.layer_count # noqa - if ( - layer_counter >= last_counter + self._config.multi_stage.layers_per_stage - or i == len(self._base_model) - 1 - ): + if layer_counter >= last_counter + self._config.multi_stage.layers_per_stage or i == len(self._layers) - 1: stage_splits.append(i + 1) last_counter = layer_counter return stage_splits diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 7829c243b..bb3133256 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -139,7 +139,9 @@ def forward( else: # TODO: Handle variable shape. 
output_global = output - kwargs["hidden_states"][self._layer_range[i]] = { + + # TODO ====== Use ====== + kwargs["hidden_states"][self._layers[i].module_name] = { "layer_type": type(layer).__name__, "tensor": output_global, } @@ -223,9 +225,9 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] and self._distributed.tensor_group is not None and not self._meta_outputs[i].is_tensor_parallel ): - check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layer_range[i]} fw") + check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layers[i].module_name} fw") if self._config.debug_layer_outputs: - name = f"layer {self._layer_range[i]} fw" + name = f"layer {self._layers[i].module_name} fw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -242,7 +244,7 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] meta=self._meta_outputs[i], ) if self._config.debug_activation_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"layer {self._layer_range[i]} fw", str)) + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"layer {self._layers[i].module_name} fw", str)) def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any], i: int) -> None: if not input_.requires_grad: @@ -254,11 +256,11 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any ): input_.register_hook( lambda grad: check_parallel_match( - grad, self._distributed.tensor_group, f"layer {self._layer_range[i]} bw" + grad, self._distributed.tensor_group, f"layer {self._layers[i].module_name} bw" ) ) if self._config.debug_layer_gradients: - name = f"layer {self._layer_range[i]} bw" + name = f"layer {self._layers[i].module_name} bw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -276,6 +278,6 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any if self._config.debug_activation_memory: input_.register_hook( lambda grad: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"layer {self._layer_range[i]} bw", str) + lambda: log_memory_usage(f"layer {self._layers[i].module_name} bw", str) ) ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index ded24e538..4778780ee 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -6,7 +6,7 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import check_parallel_match -from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -28,24 +28,17 @@ def __init__( self, *, config: StageConfig, - base_model: BaseModel | list[Layer], + layers: list[Layer], distributed_config: DistributedConfig, - begin: int, - end: int, index: int, ): super().__init__(config) self._distributed_config = distributed_config.validate() - Assert.in_range(begin, 0, end) - Assert.leq(end, len(base_model)) - self._fsdp_rank = self._distributed_config.data_rank self._fsdp_size = self._distributed_config.data_parallel self._is_setup 
= False self._index = index - - self._layers = [torch.compile(layer) if self._config.compile_all else layer for layer in base_model[begin:end]] - self._layer_range = list(range(begin, end)) + self._layers = layers parameter_metas, frozen_metas = self._get_parameter_metas() self._parameter_metas = parameter_metas + frozen_metas diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index dbdd035a4..58449f207 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -95,7 +95,7 @@ def __init__( self._num_stages = len(self._stages) self._loss_definitions = { loss_definition.name: loss_definition - for loss_definition in self._multi_stage.base_model.config.get_loss_definitions() + for loss_definition in self._multi_stage.base_model.get_loss_definitions() } def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a752bec28..aa4f2d570 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -149,7 +149,7 @@ def __init__(self, config: TrainerConfig): multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() + self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() if not self._is_evaluation_only: steps_per_split = { @@ -320,7 +320,7 @@ def _run_training(self) -> None: phase=PhaseType.test, num_iters=self._config.training.test_iters, ) - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.test) + formatted_metrics = format_metrics(metrics[metrics_key], self._loss_definitions, PhaseType.test) log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. 
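To illustrate the interface change used here (`base_model.get_loss_definitions()` instead of `base_model.config.get_loss_definitions()`), a short standalone sketch of how loss definitions now flow from the layer tree to the trainer. The classes below are toy stand-ins with hypothetical loss names; the real `LossDef` lives in fast_llm.engine.base_model.config and the real aggregation walks `get_layers()`.

import dataclasses


@dataclasses.dataclass
class LossDef:
    name: str
    formatted_name: str
    count: int = 1


class ToyMoEMLP:
    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        return [LossDef("load_balancing_loss", "load balancing loss", count)]


class ToyHead:
    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        return [LossDef("language_model_loss", "language model loss", count)]


class ToyBaseModel:
    def __init__(self):
        self._layers = [ToyMoEMLP(), ToyHead()]

    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        # Aggregation now happens over the layer tree, not over the config.
        return [loss for layer in self._layers for loss in layer.get_loss_definitions(count)]


loss_definitions = ToyBaseModel().get_loss_definitions()
# The trainer keys its running totals on the loss names, as `_train` does below.
total_losses = {loss_def.name: 0.0 for loss_def in loss_definitions}
assert set(total_losses) == {"load_balancing_loss", "language_model_loss"}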
@@ -331,7 +331,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters = 0 skipped_iters = 0 nan_iters = 0 - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions} # Profiling profiler = self._config.profiling.get_profiler( @@ -435,7 +435,9 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: **get_and_reset_memory_usage_mib(), } - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.training) + formatted_metrics = format_metrics( + metrics[metrics_key], self._loss_definitions, PhaseType.training + ) logger.info(formatted_metrics) if self._config.training.wandb.alert.enabled(self._completed_steps): self._wandb.alert("Training results", formatted_metrics, "INFO") @@ -443,7 +445,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters = 0 skipped_iters = 0 nan_iters = 0 - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions} self._run.save_logged_tensors(f"train_{self._completed_steps}") diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 2d4f049f0..167184193 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -56,6 +56,11 @@ class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]): _config: ConfigType + # Preprocessing + _backup_attention_mask: torch.Tensor + _backup_attention_mask_value: torch.Tensor + _backup_attention_tensor_cache_max_sequence_length: int = -1 + def __init__( self, config: ConfigType, @@ -431,3 +436,115 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c self.dense.get_compute_usage(dense_input, config), ) ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._rotary.preprocess(batch, kwargs) + if not self._use_flash_attention: + self._preprocess_for_backup_attention(batch, kwargs) + elif AttentionKwargs.sequence_lengths in kwargs: + self._preprocess_for_varlen(batch, kwargs) + + def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if ( + sequence_length := kwargs[AttentionKwargs.sequence_length] + ) > self._backup_attention_tensor_cache_max_sequence_length: + # Create tensor cache. 
+ self._backup_attention_tensor_cache_max_sequence_length = sequence_length + + self._backup_attention_mask = torch.ones( + (sequence_length, sequence_length), + dtype=torch.bool, + device=batch.device, + ).tril_() + + if self._config.window_size is not None: + self._backup_attention_mask.triu_(-self._config.window_size + 1) + self._backup_attention_mask_value = torch.full( + [], + torch.finfo(self._distributed_config.compute_dtype.torch).min, + dtype=self._distributed_config.compute_dtype.torch, + device=batch.device, + ) + + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ + None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + ] + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + seq_ids = torch.stack( + [ + torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) + for sample_lens in sequence_lengths + ] + ) + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] + & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] + ) + kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 + cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. + Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. + If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally + also contain previous tokens from the first document in micro-sequence. + We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. + """ + if AttentionKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: + cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] + # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents + # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets + # of the first documents so that we can index into their kv pairs + start_seq_idx = [ + torch.argmax((cu_seqlens >= sequence_k - sequence_q).to(torch.uint8), dim=0) for cu_seqlens in cumsums + ] + end_seq_idx = [torch.argmax((cu_seqlens >= sequence_k).to(torch.uint8), dim=0) for cu_seqlens in cumsums] + seqlens_q = [] + seqlens_k = [] + for idx, sample_seqlens in enumerate(sequence_lengths): + start_idx = start_seq_idx[idx] + end_idx = end_seq_idx[idx] + seqlens_q.extend([0] * start_idx) + n_attention_tokens = sample_seqlens[end_idx] - (cumsums[idx][end_idx] - sequence_k) + if start_idx == end_idx: + seqlens_q.append(sequence_q) + else: + start_q_tokens = cumsums[idx][start_idx] - (sequence_k - sequence_q) + seqlens_q.extend( + [ + start_q_tokens, + *(sample_seqlens[idx] for idx in range(start_idx + 1, end_idx)), + n_attention_tokens, + ] + ) + seqlens_k.extend(sample_seqlens[: end_idx + 1]) + seqlens_k[-1] = n_attention_tokens + seqlens_q = torch.tensor(seqlens_q, dtype=torch.int32) + seqlens_k = torch.tensor(seqlens_k, dtype=torch.int32) + else: + seqlens_q = torch.cat(sequence_lengths) + seqlens_k = torch.cat(sequence_lengths) + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), + ) + ) + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), + ) + ) + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 924e06057..68b6dde91 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -3,9 +3,7 @@ import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig @@ -126,19 +124,3 @@ def layer_class(self) -> "type[Attention]": def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. - # TODO: Find a better solution. 
- preprocessors: list[Preprocessor] = [ - self.rotary.get_layer(TensorDim("head_size", self.head_size)), - ] - if self.do_use_flash_attention(distributed_config): - from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor - - preprocessors.append(FlashAttnVarlenPreprocessor(self, distributed_config)) - else: - from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor - - preprocessors.append(BackupAttentionPreprocessor(self, distributed_config)) - return preprocessors diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py deleted file mode 100644 index 204c08ad2..000000000 --- a/fast_llm/layers/attention/preprocessing.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.tensor import TensorMeta - -logger = logging.getLogger(__name__) - - -class BackupAttentionPreprocessor(Preprocessor): - _head_size_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - self._config = config - self._distributed_config = distributed_config - assert not self._config.do_use_flash_attention(self._distributed_config) - - def _create_tensors(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._mask = torch.ones( - (sequence_length, sequence_length), - dtype=torch.bool, - device=device, - ).tril_() - - if self._config.window_size is not None: - self._mask.triu_(-self._config.window_size + 1) - self._mask_value = torch.full( - [], - torch.finfo(self._distributed_config.compute_dtype.torch).min, - dtype=self._distributed_config.compute_dtype.torch, - device=device, - ) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - kwargs[AttentionKwargs.attention_mask] = self._mask[ - None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k - ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: - seq_ids = torch.stack( - [ - torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths - ] - ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) - kwargs[AttentionKwargs.attention_mask] = ( - kwargs[AttentionKwargs.attention_mask] - & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] - ) - kwargs[AttentionKwargs.attention_mask_value] = self._mask_value - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( - ( - scalar_dim, - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - kwargs[AttentionKwargs.sequence_k_dim], - ), - tensor_name=AttentionKwargs.attention_mask, - dtype=torch.bool, - ) - 
kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( - (scalar_dim,), - tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._distributed_config.compute_dtype.torch, - ) - - -class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - assert config.do_use_flash_attention(distributed_config) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - """ - Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 - cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. - Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. - If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally - also contain previous tokens from the first document in micro-sequence. - We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. - """ - if AttentionKwargs.sequence_lengths not in kwargs: - return - sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - if sequence_q < kwargs[AttentionKwargs.sequence_length]: - cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] - # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents - # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets - # of the first documents so that we can index into their kv pairs - start_seq_idx = [ - torch.argmax((cu_seqlens >= sequence_k - sequence_q).to(torch.uint8), dim=0) for cu_seqlens in cumsums - ] - end_seq_idx = [torch.argmax((cu_seqlens >= sequence_k).to(torch.uint8), dim=0) for cu_seqlens in cumsums] - seqlens_q = [] - seqlens_k = [] - for idx, sample_seqlens in enumerate(sequence_lengths): - start_idx = start_seq_idx[idx] - end_idx = end_seq_idx[idx] - seqlens_q.extend([0] * start_idx) - n_attention_tokens = sample_seqlens[end_idx] - (cumsums[idx][end_idx] - sequence_k) - if start_idx == end_idx: - seqlens_q.append(sequence_q) - else: - start_q_tokens = cumsums[idx][start_idx] - (sequence_k - sequence_q) - seqlens_q.extend( - [ - start_q_tokens, - *(sample_seqlens[idx] for idx in range(start_idx + 1, end_idx)), - n_attention_tokens, - ] - ) - seqlens_k.extend(sample_seqlens[: end_idx + 1]) - seqlens_k[-1] = n_attention_tokens - seqlens_q = torch.tensor(seqlens_q, dtype=torch.int32) - seqlens_k = torch.tensor(seqlens_k, dtype=torch.int32) - else: - seqlens_q = torch.cat(sequence_lengths) - seqlens_k = torch.cat(sequence_lengths) - kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=batch.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), - ) - ) - kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=batch.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), - ) - ) - kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 5e24af9a6..74b5cf21a 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -4,7 +4,7 @@ import warnings from fast_llm.config import Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert @@ -14,7 +14,7 @@ @config_class(registry=True) -class RotaryConfig(BaseModelConfig): +class RotaryConfig(ModuleConfig): # TODO: Move rotary to its own submodule. 
@classmethod @@ -137,7 +137,7 @@ def _get_configurable_class(self) -> "type[YarnRotary]": return YarnRotary -@config_class(dynamic_type={RotaryConfig: "rope_2d"}) +@config_class(dynamic_type={RotaryConfig: "default_2d"}) class Rotary2DConfig(DefaultRotaryConfig): def _get_configurable_class(self) -> "type[Rotary2D]": from fast_llm.layers.transformer.rotary.rotary import Rotary2D diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 189596352..6250fd4a9 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -5,8 +5,7 @@ import torch from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.attention.rotary.config import ( @@ -17,8 +16,6 @@ RotaryConfig, YarnRotaryConfig, ) -from fast_llm.layers.vision.config import VisionEncoderKwargs -from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -43,7 +40,7 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, config: ConfigType, @@ -58,6 +55,9 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: pass + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + pass + class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( @@ -65,12 +65,6 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return query, key - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - pass - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass - class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor @@ -84,26 +78,6 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None ] kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - self._head_size_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - self._head_size_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -206,52 +180,30 @@ def _get_correction(self, beta: float, dim: int) -> float: class Rotary2D[ConfigType: Rotary2DConfig](DefaultRotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_num_patches: int = -1 + _config: ConfigType def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // 
kwargs[VisionEncoderKwargs.patch_size] - self._create_tensors(max_num_patches) - position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] - kwargs[VisionTransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] - kwargs[VisionTransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - kwargs[VisionTransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=VisionTransformerKwargs.rotary_freq_q, - ) - kwargs[VisionTransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[TransformerKwargs.sequence_k_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=VisionTransformerKwargs.rotary_freq_k, + self._create_tensors( + kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size], batch.device ) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] - def _create_tensors(self, max_num_patches: int) -> None: - if max_num_patches <= self._tensor_cache_max_num_patches: - return - self._tensor_cache_max_num_patches = max_num_patches - - self._rotary_embedding_frequencies = self._get_frequencies( - max_num_patches, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) + def forward( + self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] + ) -> tuple[torch.Tensor, torch.Tensor]: + rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) + return query, key - def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.device) -> torch.Tensor: + max_num_patches = sequence_length # Calculate complex frequencies by using alternating channels for width and height height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) - frequencies = self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / head_size, device=device, dtype=torch.float64) angles_h = torch.outer(height_positions, frequencies[::2]) angles_w = torch.outer(width_positions, frequencies[1::2]) angles = torch.cat( @@ -260,12 +212,12 @@ def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda" angles_w[None, :, :].repeat(max_num_patches, 1, 1), ], dim=-1, - ).reshape(-1, kv_channels // 2) + ).reshape(-1, head_size // 2) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) if not self._config.complex_format: frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + torch.view_as_real(frequencies).flatten(-2), head_size, 3 ).contiguous() return frequencies diff --git a/fast_llm/layers/block/block.py 
b/fast_llm/layers/block/block.py index 0f975c9c5..ab6cb22b0 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -4,13 +4,13 @@ import torch -from fast_llm.config import Config, Configurable -from fast_llm.engine.base_model.base_model import Layer, Module -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer, LayerBase +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.logging import get_model_debug_level, log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -93,9 +93,9 @@ def __call__[ ) -class BaseBlock[ConfigType: Config](Configurable[ConfigType], Module): +class BlockBase[ConfigType: ModuleConfig](Configurable[ConfigType], LayerBase): """ - Base class for blocks and block-like layers (mlp, mixers, etc.). + Base class for blocks and block-like layers (mlp, mixers, block sequences, etc.). """ def __init__( @@ -115,11 +115,6 @@ def __init__( self._lr_scale = lr_scale self._peft = peft - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() - -class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): - """ - Base class for actual blocks, i.e., base blocks that are also `Layers`. - """ +class Block[ConfigType: BlockConfig](BlockBase[ConfigType], Layer): + pass diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 47c1ab9b7..f3e93edeb 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,11 +1,9 @@ -import abc -import collections import functools import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -13,7 +11,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import Block + from fast_llm.layers.block.block import BlockBase + from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence class BlockDimNames: @@ -40,8 +39,8 @@ class BlockKwargs: grad_output = "grad_output" -@config_class() -class BaseBlockConfig(BaseModelConfig): +@config_class(registry=True) +class BlockConfig(ModuleConfig): """ Base configuration class for blocks and block-like layers (mlp, mixers, etc.). 
""" @@ -55,19 +54,6 @@ class BaseBlockConfig(BaseModelConfig): hint=FieldHint.feature, ) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return [] - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return [] - - -@config_class(registry=True) -class BlockConfig(BaseBlockConfig): - """ - Base configuration class for actual blocks, i.e., base blocks that are also `Layers`. - """ - @classmethod def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: @@ -78,16 +64,16 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return super()._from_dict(default, strict=strict) @property - def layer_class(self) -> "type[Block]": + def layer_class(self) -> "type[BlockBase]": raise NotImplementedError() - def get_block( + def get_layer( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ) -> "Block": + ) -> "BlockBase": return self.layer_class( self, distributed_config, @@ -98,7 +84,7 @@ def get_block( @config_class(registry=True) -class BlockSequenceConfig(BaseModelConfig): +class BlockSequenceConfig(BlockConfig): @classmethod def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: @@ -106,21 +92,6 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return FixedBlockSequenceConfig._from_dict(default, strict) return super()._from_dict(default, strict=strict) - @abc.abstractmethod - def __len__(self) -> int: - pass - - @abc.abstractmethod - def __getitem__(self, index: int) -> BlockConfig: - pass - - @abc.abstractmethod - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - pass - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return [] - @config_class(dynamic_type={BlockSequenceConfig: "fixed"}) class FixedBlockSequenceConfig(BlockSequenceConfig): @@ -136,18 +107,11 @@ class FixedBlockSequenceConfig(BlockSequenceConfig): valid=check_field(Assert.geq, 0), ) - def __len__(self) -> int: - return self.num_blocks - - def __getitem__(self, index: int) -> BlockConfig: - return self.block - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Prevent name conflicts in preprocessed kwargs. 
- return self.block.get_preprocessors(distributed_config) + @property + def layer_class(self) -> "type[FixedBlockSequence]": + from fast_llm.layers.block.sequence import FixedBlockSequence - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.block.get_loss_definitions(count=count * self.num_blocks) + return FixedBlockSequence @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) @@ -180,26 +144,18 @@ def _validate(self): super()._validate() - def __len__(self) -> int: - return self.num_blocks + @property + def layer_class(self) -> "type[PatternBlockSequence]": + from fast_llm.layers.block.sequence import PatternBlockSequence - def __getitem__(self, index: int) -> BlockConfig: - return self.blocks[self.expanded_pattern[index]] + return PatternBlockSequence @functools.cached_property def expanded_pattern(self) -> list[str]: + # The complete list of block names, expanded to `num_blocks` return (self.pattern * (self.num_blocks // len(self.pattern) + 1))[: self.num_blocks] - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Prevent name conflicts in preprocessed kwargs. - return sum((block.get_preprocessors(distributed_config) for block in self.blocks.values()), []) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - # TODO: Prevent name conflicts. - return sum( - ( - self.blocks[name].get_loss_definitions(count=count * count_) - for name, count_ in collections.Counter(self.expanded_pattern).items() - ), - [], - ) + @functools.cached_property + def preprocessing_layers(self) -> dict[str, int]: + # The index at which each block first appears. These blocks are used for preprocessing. + return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)} diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index e69de29bb..57621a848 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -0,0 +1,108 @@ +import collections + +import torch.nn + +from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.peft.config import PeftConfig + + +class FixedBlockSequence[ConfigType: FixedBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.extend( + layers := [ + self._config.block.get_layer( + distributed_config, + hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + for _ in range(self._config.num_blocks) + ] + ) + # Wrap all blocks in a namespace using the unique module name of the first one. + namespace = layers[0].module_name if self._config.num_blocks > 0 else "" + # Note: Pytorch won't redundantly register modules because it doesn't look into lists. 
+ self._layers_with_namespace = [ + LayerWithNamespace(sublayer, namespace) for layer in layers for sublayer in layer.get_layers() + ] + + def get_layers(self) -> list["Layer"]: + return self._layers_with_namespace + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self[0].get_loss_definitions(count=count * self.num_blocks) if self._config.num_blocks > 0 else [] + + +class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.extend( + layers := [ + self._config.blocks[name].get_layer( + distributed_config, + hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + for name in self._config.expanded_pattern + ] + ) + # Wrap each set of blocks with identical config in a namespace + # using the unique module name of the first such block. + # Note: Pytorch won't redundantly register modules because it doesn't look into lists. + self._layers_with_namespace = [ + LayerWithNamespace(sublayer, layers[self._config.preprocessing_layers[name]].module_name) + for name, layer in zip(self._config.expanded_pattern, layers) + for sublayer in layer.get_layers() + ] + + def get_layers(self) -> list["Layer"]: + return self._layers_with_namespace + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # TODO: Prevent name conflicts. + return sum( + ( + self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_) + for name, count_ in collections.Counter(self.expanded_pattern).items() + ), + [], + ) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index c1ced10df..a80a19280 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -26,7 +26,7 @@ class NormalizationImplementation(str, enum.Enum): @config_class(registry=True) -class NormalizationConfig(BaseModelConfig): +class NormalizationConfig(ModuleConfig): lr_scale: float | None = Field( default=None, desc="Scaling factor for the layer learning rate." 
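Going back to `PatternBlockSequenceConfig` above, a small worked example of `expanded_pattern` and `preprocessing_layers` (plain Python, hypothetical block names), using the same formulas as the cached properties:

pattern = ["attention", "mamba"]
num_blocks = 5

# Repeat the pattern until it is long enough, then truncate to `num_blocks`.
expanded_pattern = (pattern * (num_blocks // len(pattern) + 1))[:num_blocks]
assert expanded_pattern == ["attention", "mamba", "attention", "mamba", "attention"]

# First index at which each named block appears; these blocks own the
# preprocessing namespace shared by all blocks built from the same config.
preprocessing_layers = {name: expanded_pattern.index(name) for name in set(expanded_pattern)}
assert preprocessing_layers == {"attention": 0, "mamba": 1}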
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index ce9893ea4..08dd5a815 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -4,22 +4,21 @@ import torch -from fast_llm.config import Config from fast_llm.core.distributed import set_generator -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) -class BlockWithBias[ConfigType: Config](Block[ConfigType]): +class BlockWithBias[ConfigType: BlockWithBiasConfig](Block[ConfigType]): """ Base class for mixer and MLP modules. """ @@ -85,12 +84,9 @@ def __init__( ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input = return_input - # Note, layer_lr_scale does not impact the norms - # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. 
self.mixer = self._config.mixer.get_layer( self._distributed_config, self._hidden_dim, @@ -178,3 +174,6 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c self.mlp.get_compute_usage(input_, kwargs, config), ) ) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 724b8d172..403b204c8 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -1,7 +1,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -96,7 +95,7 @@ def layer_class(self) -> "type[DecoderBlock]": return DecoderBlock - def get_block( + def get_layer( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, @@ -112,9 +111,3 @@ def get_block( peft=peft, return_input=return_input, ) - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 100f53740..36841b45b 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,7 +3,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.layers.decoder.config import MLPBaseConfig @@ -152,23 +151,3 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_definitions = [] - if self.routing == RoutingType.topk: - loss_definitions.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=1, - ) - ) - if self.z_loss_coefficient: - loss_definitions.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=1, - ) - ) - return loss_definitions diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index d4cb46dbf..ffc9eadba 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -263,6 
+263,26 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_definitions = [] + if self._config.routing == RoutingType.topk: + loss_definitions.append( + LossDef( + name=MLPLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=1, + ) + ) + if self._config.z_loss_coefficient: + loss_definitions.append( + LossDef( + name=MLPLossNames.router_z_loss, + formatted_name="router z loss", + count=1, + ) + ) + return loss_definitions + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e852cf96f..2c5032967 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor +from fast_llm.engine.base_model.config import LossDef, ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -9,26 +9,22 @@ from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.vision.config import VisionEncoderConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelLossNames: - language_model_loss = "language_model_loss" - z_loss = "z_loss" - dpo_loss = "dpo_loss" - distil_lm_loss = "distillation_language_model_loss" # the next token perdiciton of combined distillation loss - distillation_loss = "distillation_loss" - - @staticmethod - def multi_token_prediction_loss(index: int) -> str: - if index == 0: - return LanguageModelLossNames.language_model_loss - return f"language_model_loss_{index}" +# class LanguageModelLossNames: +# language_model_loss = "language_model_loss" +# z_loss = "z_loss" +# dpo_loss = "dpo_loss" +# distil_lm_loss = "distillation_language_model_loss" # the next token perdiciton of combined distillation loss +# distillation_loss = "distillation_loss" class LanguageModelKwargs(BlockKwargs): @@ -108,17 +104,37 @@ def layer_class(self) -> "type[LanguageModelEmbedding]": return LanguageModelEmbedding - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - preprocessors = [] - if self.position_embeddings.enabled: - from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor - preprocessors.append(PositionEmbeddingPreprocessor(self, distributed_config)) - return preprocessors +@config_class(registry=True) +class LanguageModelHeadBaseConfig(BlockConfig): + @classmethod + def _from_dict(cls, default: dict[str, 
typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return LanguageModelHeadConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + def get_layer( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + return self.layer_class( + self, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + ) -@config_class() -class LanguageModelHeadConfig(BlockConfig): +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "default"}) +class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the final normalization layer.", @@ -129,17 +145,6 @@ class LanguageModelHeadConfig(BlockConfig): desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - tied_weight: bool = Field( - default=True, - desc="Tie the output weights (logits) with the vocabulary embedding.", - hint=FieldHint.architecture, - ) - prediction_heads: int = Field( - default=1, - desc="Number of multi-token prediction heads.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) cross_entropy_implementation: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", @@ -181,12 +186,6 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - prediction_loss_coefficient: list[float] | None = Field( - default=None, - desc="Loss coefficient for each prediction head.", - doc="If not provided, all heads are equally weighted.", - hint=FieldHint.feature, - ) teacher_softmax_temperature: float = Field( default=1.0, desc="Divides distillation target logits by this factor.", @@ -216,6 +215,30 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, ) + def get_layer( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + prediction_distance: int = 0, + prediction_heads: int = 1, + loss_coefficient: float = 1.0, + ): + return self.layer_class( + self, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + loss_coefficient=loss_coefficient, + ) + @property def layer_class(self) -> "type[LanguageModelHead]": from fast_llm.layers.language_model.head import LanguageModelHead @@ -230,125 +253,76 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() - if self.distillation_model is not None: - if self.prediction_heads > 1: - raise NotImplementedError("Multi-token prediction not supported with distillation.") + + +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) +class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): + _abstract = False + # Needs to be `DecoderBlockConfig` for the `return_input` interface. + # TODO: Make a generic wrapper for returning input instead? 
+ # TODO ====== Tied weight ====== + block: DecoderBlockConfig = Field( + desc="Configuration for the decoder block before each head.", + hint=FieldHint.architecture, + ) + # TODO: Generalize? (needs the extra initialization arguments) + head: LanguageModelHeadConfig = Field( + desc="Configuration for the multi-token-prediction heads.", + hint=FieldHint.architecture, + ) + prediction_heads: int = Field( + default=1, + desc="Prediction heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + # TODO ====== Adjust ====== + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + super()._validate() if isinstance(self.prediction_loss_coefficient, list): Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - preprocessors: list[Preprocessor] = [] - - if self.enable_dpo: # TODO better way to pass in? - from fast_llm.layers.language_model.preprocessing import PreferenceSpanPreprocessor - - preprocessors.append(PreferenceSpanPreprocessor()) + @property + def layer_class(self) -> "type[MultiTokenPrediction]": + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction - return preprocessors + return MultiTokenPrediction def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [] - if self.logit_z_loss: - LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=count) - - if self.enable_dpo: - loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=count)) - - if self.distillation_model is not None: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=count) - ) - if self.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=count - ) - ) - - for i in range(self.prediction_heads): - loss_defs.append( - LossDef( - name=LanguageModelLossNames.multi_token_prediction_loss(i), - formatted_name=f"language model loss {i}", - count=count, - ) - ) - return loss_defs - - def get_block( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - prediction_distance: int = 0, - ): - return self.layer_class( - self, - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - peft=peft, - prediction_distance=prediction_distance, + # TODO ====== Wrong ====== + return self.block.get_loss_definitions(count=count * self.prediction_heads) + self.head.get_loss_definitions( + count=count * self.prediction_heads ) - def get_blocks( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - mtp_block_config: BlockConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - ): - blocks = [] - for i in range(self.prediction_heads): - if i > 0: - blocks.append( - mtp_block_config.get_block( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - # The last block only 
returns the model output. - # The previous blocks return a stack of shared_hidden and transformer_output. - return_input=i < self.prediction_heads - 1, - ) - ) - blocks.append( - self.get_block( - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - prediction_distance=i, - ) - ) - return blocks - -# TODO: `BlockSequenceConfig`? (interface not fully compatible) @config_class() -class LanguageModelBaseConfig(BaseModelConfig): +class LanguageModelConfig(ModuleConfig): # TODO: block decoder: BlockSequenceConfig = Field( desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) embeddings_layer: LanguageModelEmbeddingsConfig = Field() - output_layer: LanguageModelHeadConfig = Field() + output_layer: LanguageModelHeadBaseConfig = Field() # TODO: Allow overriding in sub-models? peft: PeftConfig = Field( desc="Configuration for parameter-efficient fine tuning.", hint=FieldHint.architecture, ) + tied_embedding_weight: bool = Field( + default=False, + desc="Tie the output weights (logits) with the vocabulary embedding.", + hint=FieldHint.architecture, + ) sequence_first: bool | None = Field( default=None, desc="Override the default dimension ordering", @@ -358,66 +332,16 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.testing, ) - def __len__(self) -> int: - return len(self.decoder) + 2 * self.output_layer.prediction_heads - - def __getitem__(self, index: int) -> BlockConfig: - if index <= 0: - Assert.eq(index, 0) - return self.embeddings_layer - elif index <= len(self.decoder): - return self.decoder[index - 1] - else: - # Start at the last decoder layer so all MTP heads are treated similarly. - index - len(self.decoder) - return self.embeddings_layer - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return ( - self.embeddings_layer.get_preprocessors(distributed_config) - + self.decoder.get_preprocessors(distributed_config) - + self.output_layer.get_preprocessors(distributed_config) - ) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self.embeddings_layer.get_loss_definitions(count) - + self.decoder.get_loss_definitions(count) - + self.output_layer.get_loss_definitions(count) - ) - - def get_blocks(self, distributed_config: DistributedConfig): - hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) - return [ - self.embeddings_layer.get_block( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=None, - peft=self.peft, - ), - *[ - self.decoder[i].get_block( - distributed_config, - hidden_dim, - lr_scale=None, - peft=self.peft, - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - # TODO: Not all blocks support this argument. 
- **( - {"return_input": True} - if self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1 - else {} - ), - ) - for i in range(len(self.decoder)) - ], - *self.output_layer.get_blocks( - distributed_config, - self.embeddings_layer, - self.decoder[len(self.decoder) - 1], - hidden_dim=hidden_dim, - lr_scale=None, - peft=self.peft, - ), - ] + # def __len__(self) -> int: + # return len(self.decoder) + 2 * self.output_layer.prediction_heads + + # def __getitem__(self, index: int) -> BlockConfig: + # if index <= 0: + # Assert.eq(index, 0) + # return self.embeddings_layer + # elif index <= len(self.decoder): + # return self.decoder[index - 1] + # else: + # # Start at the last decoder layer so all MTP heads are treated similarly. + # index - len(self.decoder) + # return self.embeddings_layer diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 8625945ce..17ee3261b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -28,6 +28,11 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType + # Preprocessing + _rotary_embedding_frequencies: torch.Tensor + _position_ids: torch.Tensor + _tensor_cache_max_sequence_length: int = -1 + def __init__( self, config: ConfigType, @@ -161,3 +166,30 @@ def forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (embeddings) return 0 + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if not self._config.position_embeddings.enabled: + return + self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: + position_ids = torch.stack( + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + ).to(batch.device, dtype=torch.int64) + position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] + if kwargs[LanguageModelKwargs.sequence_first]: + position_ids = position_ids.transpose(0, 1) + kwargs[LanguageModelKwargs.position_ids] = position_ids + else: + kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ + sequence_k - sequence_q : sequence_k + ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) + + def _create_position_embeddings(self, sequence_length: int, device: torch.device) -> None: + if sequence_length <= self._tensor_cache_max_sequence_length: + return + self._tensor_cache_max_sequence_length = sequence_length + + Assert.leq(sequence_length, self._config.num_position_embeddings) + self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4b080b360..42b7e3d6c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,3 +1,4 @@ +import functools import logging import typing @@ -6,7 +7,7 @@ from torch.distributed import all_reduce from fast_llm.core.ops import split_op -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import 
init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -23,7 +24,6 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadConfig, LanguageModelKwargs, - LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import TensorMeta @@ -51,7 +51,9 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - prediction_distance: int, + prediction_distance: int = 0, + prediction_heads: int = 1, + loss_coefficient: float = 1.0, ): super().__init__( config, @@ -60,6 +62,17 @@ def __init__( lr_scale=lr_scale, peft=peft, ) + if prediction_distance > 0 and ( + self._config.distillation_model is not None or self._config.dpo_reference_model is not None + ): + raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") + + Assert.in_range(prediction_distance, 0, prediction_heads) + self._prediction_distance = prediction_distance + self._prediction_heads = prediction_heads + self._loss_coefficient = loss_coefficient + self._is_last_head = self._prediction_distance == self._prediction_heads - 1 + self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -67,20 +80,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - self._loss_coefficient = ( - self._config.prediction_loss_coefficient[prediction_distance] - if self._config.prediction_loss_coefficient - else 1.0 - ) - self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - - # Distance of the target token prediction - # 0: next-token prediction - # >0: multi-token prediction (MTP) - Assert.geq(prediction_distance, 0) - self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - if not self._config.enable_dpo: self._cross_entropy_impl = self._config.cross_entropy_implementation if self._cross_entropy_impl == CrossEntropyImpl.auto: @@ -222,9 +221,7 @@ def _get_targets( if lm_target is not None: # MTP: Shift the labels lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) - + 1 - - self._config.prediction_heads + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads ) if LanguageModelKwargs.sequence_q_dim in kwargs: Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) @@ -336,7 +333,7 @@ def _logits_cross_entropy_forward_backward( self.training, grad_output, losses, - LanguageModelLossNames.z_loss, + self._z_loss_name, logits_scale_factor=self._config.logits_scale_factor, ) if self._debug.enabled and self._config.cross_entropy_splits is None: @@ -424,14 +421,81 @@ def _logits_cross_entropy_forward_backward( loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) if self.training and losses is not None: if dpo_loss is not None: - losses[LanguageModelLossNames.dpo_loss].append(dpo_loss.detach()) + losses[self._dpo_loss_name].append(dpo_loss.detach()) if self._config.distillation_model is not None and distillation_loss is not None: - losses[LanguageModelLossNames.distillation_loss].append(distillation_loss.detach()) + 
losses[self._distillation_language_model_loss_name].append(distillation_loss.detach()) if self._config.distillation_model is not None and lm_loss is not None: - losses[LanguageModelLossNames.distil_lm_loss].append(lm_loss.detach()) + losses[self._distillation_loss_name].append(lm_loss.detach()) return loss, output_parallel_linear_backward(grad, context) if self.training else None + @functools.cached_property + def _loss_name(self) -> str: + name = "language_model_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _z_loss_name(self) -> str: + name = "z_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _dpo_loss_name(self) -> str: + name = "dpo_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _distillation_language_model_loss_name(self) -> str: + name = "distillation_language_model_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _distillation_loss_name(self) -> str: + name = "distillation_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] + if self._config.logit_z_loss: + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + if self._config.enable_dpo: + loss_defs.append( + LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) + ) + + if self._config.distillation_model is not None: + loss_defs.append( + LossDef( + name=self._distillation_loss_name, + formatted_name=_format_name(self._distillation_loss_name), + count=count, + ) + ) + if self._config.language_model_loss_factor > 0.0: + loss_defs.append( + LossDef( + name=self._distillation_language_model_loss_name, + formatted_name=_format_name(self._distillation_language_model_loss_name), + count=count, + ) + ) + + return loss_defs + + +def _format_name(name: str) -> str: + return name.replace("_", " ") + def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: tensors = [tensor for tensor in tensors if tensor is not None] diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py new file mode 100644 index 000000000..79555d866 --- /dev/null +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -0,0 +1,77 @@ +import torch + +from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig + + +class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](BlockBase[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + 
super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.blocks = torch.nn.ModuleList( + [ + self._config.block.get_layer( + self._distributed_config, + self._hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + # The last block only returns the model output. + # The previous blocks return a stack of shared_hidden and transformer_output. + return_input=index < self._config.prediction_heads - 1, + ) + for index in range(self._config.prediction_heads) + ] + ) + self.heads = torch.nn.ModuleList( + [ + self._config.head.get_layer( + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + prediction_distance=index, + prediction_heads=self._config.prediction_heads, + loss_coefficient=( + 1.0 + if self._config.prediction_loss_coefficient is None + else self._config.prediction_loss_coefficient[index] + ), + ) + for index in range(self._config.prediction_heads) + ] + ) + + # Wrap all blocks in a namespace using the unique module name of the first one. + namespace = self.blocks[0].module_name + # Note: Pytorch won't redundantly register modules because it doesn't look into lists. + self._blocks_with_namespace = [ + LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers() + ] + + def get_layers(self) -> list[Layer]: + return [ + module + for block, head in zip(self._blocks_with_namespace, self.heads, strict=True) + for module in (block, head) + ] diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py deleted file mode 100644 index fc1dac299..000000000 --- a/fast_llm/layers/language_model/preprocessing.py +++ /dev/null @@ -1,107 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -class PositionEmbeddingPreprocessor(Preprocessor): - _rotary_embedding_frequencies: torch.Tensor - _position_ids: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__(self, config: LanguageModelEmbeddingsConfig, distributed_config: DistributedConfig): - self._config = config - assert config.position_embeddings.enabled - self._distributed_config = distributed_config - - def _create_tensors(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - Assert.leq(sequence_length, self._config.num_position_embeddings) - self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: - position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(batch.device, dtype=torch.int64) - position_ids = 
position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[LanguageModelKwargs.sequence_first]: - position_ids = position_ids.transpose(0, 1) - kwargs[LanguageModelKwargs.position_ids] = position_ids - else: - kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ - sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - # Position embeddings will be broadcast. - sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] - kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( - ( - (sequence_q_dim, scalar_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (scalar_dim, sequence_q_dim) - ), - tensor_name=LanguageModelKwargs.position_ids, - dtype=torch.int64, - ) - - -class PreferenceSpanPreprocessor(Preprocessor): - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - return - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - - if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: - raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") - - chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] - chosen_valid_spans = [] - for spans in chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] - rejected_valid_spans = [] - for spans in rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 966c7843b..fbd0ed807 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -2,8 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.linear.config import Convolution2DConfig from fast_llm.layers.common.normalization.config import NormalizationConfig @@ -13,31 +12,6 @@ pass -class VisionEncoderDimNames: - in_channels = "vision_in_channels" - 
out_channels = "vision_out_channels" - adapter_size = "vision_adapter_size" - patch_size = "vision_patch_size" - kv_channels = "vision_kv_channels" - - -class VisionEncoderKwargs: - patch_size = "patch_size" - images = "images" - image_patches = "image_patches" - image_positions = "image_positions" - max_image_size = "max_image_size" - image_sizes = "image_sizes" - rope_theta = "vit_rope_theta" - rotary_inv_freq = "vit_rotary_inv_freq" - kv_channels = "vit_kv_channels" - max_image_tokens = "max_image_tokens" - patch_embeddings = "patch_embeddings" - hidden_dims = "vit_hidden_dims" - image_patches_meta = "vit_image_patches_meta" - out_channels = "vit_out_channels" - - @config_class() class ImageNormalizationConfig(Config): mean_r: float = Field( @@ -107,7 +81,7 @@ class PatchConvolutionConfig(BlockConfig): @config_class(registry=True) -class VisionEncoderConfig(BaseModelConfig): +class VisionEncoderConfig(ModuleConfig): _abstract = False patch_convolution_layer: PatchConvolutionConfig = Field( desc="Configuration for the patch convolution layer.", @@ -199,13 +173,6 @@ def _validate(self) -> None: self.adapter_init_method_std = self.adapter_size**-0.5 super()._validate() - def setup_tensor_space(self, tensor_space: "TensorSpace"): - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) - self.transformer.setup_tensor_space(tensor_space) - @property def enabled(self) -> bool: return self.type != VisionEncoderType.none diff --git a/fast_llm/layers/vision/patch_convolution.py b/fast_llm/layers/vision/patch_convolution.py index b40a9593e..46cf86708 100644 --- a/fast_llm/layers/vision/patch_convolution.py +++ b/fast_llm/layers/vision/patch_convolution.py @@ -4,11 +4,11 @@ from fast_llm.core.ops import split from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.vision.config import PatchConvolutionConfig, VisionEncoderKwargs +from fast_llm.layers.vision.config import PatchConvolutionConfig from fast_llm.tensor import TensorMeta @@ -34,6 +34,7 @@ def __init__( ) input_dim = TensorDim("input_channels", self._config.input_channels) patch_dim = TensorDim("patch", self._config.patch_size) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self.convolution = self._config.convolution.get_layer( self._hidden_dim, @@ -55,23 +56,16 @@ def forward( metrics: dict | None = None, ) -> torch.Tensor: if isinstance(input_, TensorMeta): - # TODO: Replace last dim instead - hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] - return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] - sequence_length = kwargs[AttentionKwargs.sequence_length] - out_channels = kwargs[VisionEncoderKwargs.out_channels] + return TensorMeta.from_dims( + input_.dims[:-1] + (self._hidden_dim,), tensor_name="patch conv output", 
dtype=input_.dtype + ) # TODO: Avoid padding - reshape_dims = (micro_batch_size, sequence_length, out_channels) - group = self._tensor_space.distributed.tensor_group - input_ = self.convolution(input_) - patch_embeddings = self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.view(reshape_dims) + patch_embeddings = self.normalization(input_.flatten(1)).view_as(input_) - # TODO: Sequence first" - if sequence_first: + # TODO: Permute earlier? + if kwargs[AttentionKwargs.sequence_first]: patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() if self._sequence_parallel: - patch_embeddings = split(patch_embeddings, group=group, dim=0) + patch_embeddings = split(patch_embeddings, group=self._parallel_dim.group, dim=0) return patch_embeddings diff --git a/fast_llm/layers/vision/preprocessing.py b/fast_llm/layers/vision/preprocessing.py index 83246608e..83331c739 100644 --- a/fast_llm/layers/vision/preprocessing.py +++ b/fast_llm/layers/vision/preprocessing.py @@ -4,13 +4,10 @@ import torch import torchvision.transforms.v2 as torchvision_transforms -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.vision.config import ImageNormalizationConfig, VisionEncoderConfig, VisionEncoderKwargs -from fast_llm.tensor import TensorMeta +from fast_llm.layers.vision.config import ImageNormalizationConfig, VisionEncoderConfig from fast_llm.utils import div @@ -66,26 +63,6 @@ def resize(image: torch.Tensor, target_height: int, target_width: int) -> torch. ) -def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: - freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) - max_patches_per_side = max_image_size // patch_size - - h = torch.arange(max_patches_per_side) - w = torch.arange(max_patches_per_side) - - freqs_h = torch.outer(h, freqs[::2]).float() - freqs_w = torch.outer(w, freqs[1::2]).float() - inv_freq = torch.cat( - [ - freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), - freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), - ], - dim=-1, - ).reshape(-1, kv_channels // 2) - - return torch.cat((inv_freq, inv_freq), dim=-1) - - def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: patch_height = height // patch_size patch_width = width // patch_size @@ -94,25 +71,11 @@ def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tenso ) -class VisionPreprocessor(Preprocessor): +class VisionPreprocessor: def __init__(self, config: VisionEncoderConfig, distributed: Distributed): self._config = config self._distributed = distributed - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( - ( - TensorDim( - "vision_batch", - kwargs[AttentionKwargs.micro_batch_size] * kwargs[AttentionKwargs.sequence_q_dim].size, - ), - TensorDim("in_channels", 3), - TensorDim("patch_size", self._config.patch_size), - TensorDim("patch_size", self._config.patch_size), - ), - dtype=self._distributed.config.training_dtype.torch, - ) - def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) patch_size = self._config.patch_size @@ 
-203,12 +166,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionTransformerKwargs.patch_position_ids] = torch.cat(patch_position_ids).to( device=self._distributed.device ) - kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( - kwargs[VisionEncoderKwargs.rope_theta], - kwargs[VisionEncoderKwargs.kv_channels], - max_image_size, - patch_size, - ).to(device=self._distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size**2, patch_size**2) # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 58d1ed04c..6721daea2 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,24 +4,22 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, AutoGPTHuggingfaceCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, LlamaCheckpointFormat, - LlavaCheckpointFormat, - LlavaHybridCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, MTPLlamaCheckpointFormat, - PixtralCheckpointFormat, Qwen2CheckpointFormat, ) from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds @@ -29,7 +27,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM - from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel + from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.gpt.trainer import GPTTrainer logger = logging.getLogger(__name__) @@ -83,7 +81,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelBaseConfig): +class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig): _abstract = False # Debug, to get an exact match with megatron init. 
@@ -91,6 +89,12 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) + @property + def base_model_class(self) -> type["GPTBaseModel"]: + from fast_llm.models.gpt.model import GPTBaseModel + + return GPTBaseModel + @config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): @@ -107,9 +111,6 @@ class GPTModelConfig(FastLLMModelConfig): DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, AprielHybridSSMCheckpointFormat, - PixtralCheckpointFormat, - LlavaCheckpointFormat, - LlavaHybridCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b5f429376..9ddcf6300 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,7 +5,7 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer -from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -13,8 +13,6 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -37,22 +35,42 @@ def __init__( ): self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) + + hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) + self.embedding = self._config.embeddings_layer.get_layer( + distributed_config, + hidden_dim=hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.decoder = self._config.decoder.get_layer( + distributed_config, + hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.head = self._config.output_layer.get_layer( + distributed_config, + self._config.embeddings_layer, + hidden_dim=hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron( param, self._config.decoder.block, config.embeddings_layer.hidden_size ) # Noqa - # `self._reference_models` is not populated at this point, so we pass a mutable dict. 
- self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) # TODO ====== Vision ====== # if self._config.vision_encoder.enabled: # self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) # self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) - def get_layers(self) -> list[Layer]: - return self._config.get_blocks(self._distributed_config) + def get_layers(self) -> list["Layer"]: + return self.embedding.get_layers() + self.decoder.get_layers() + self.head.get_layers() # TODO ====== Vision ====== # def get_vision_layers(self) -> list[Layer]: @@ -67,16 +85,10 @@ def get_layers(self) -> list[Layer]: # MultiModalEmbedding(self._config, self._tensor_space), # ] - def get_embedding_layers(self) -> list[Layer]: - if self._config.vision_encoder.enabled: - return self.get_vision_layers() - else: - return [LanguageModelEmbedding(self._config, self._tensor_space)] - def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: - # TODO: How much of this is generalizable? + # TODO ====== Remove (Move batch splitting elsewhere) ====== # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence if isinstance(batch_meta, GPTBatchConfig): @@ -193,8 +205,6 @@ def preprocess_meta( kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( hidden_dims[:2], tensor_name="labels", dtype=torch.int64 ) - for preprocessor in self._preprocessors: - preprocessor.preprocess_meta(kwargs) reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] @@ -228,7 +238,7 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO: How much of this is generalizable? 
+ # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== assert self._is_setup if preprocessed_meta is None: @@ -319,6 +329,39 @@ def preprocess( labels[start : end + 1, idx] = -100 else: labels[idx, start : end + 1] = -100 + + # TODO ====== Preference spans ====== + if batch.chosen_spans is not None: + chosen_valid_spans = [] + for spans in batch.chosen_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans + + rejected_valid_spans = [] + for spans in batch.rejected_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans + # TODO ====== Vision ====== # if self._config.vision_encoder.enabled: # if self._config.vision_encoder.image_break_token is not None: @@ -358,8 +401,10 @@ def preprocess( # ) # kwargs[LanguageModelKwargs.tokens] = tokens - for preprocessor in self._preprocessors: - preprocessor.preprocess(tokens, kwargs) + # TODO ====== Turn into super() call ====== + self.embedding.preprocess(tokens, kwargs) + self.decoder.preprocess(tokens, kwargs) + self.head.preprocess(tokens, kwargs) # TODO ====== Vision ====== # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) @@ -372,10 +417,6 @@ def preprocess( return preprocessed - @property - def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] - # TODO ====== Vision ====== # @property # def embedding(self) -> LanguageModelEmbedding: @@ -392,37 +433,41 @@ def embedding(self) -> LanguageModelEmbedding: # else: # return 0 - @property - def model_head(self) -> LanguageModelHead: - return self.layers[self.model_head_indices[0]] - - @property - def model_head_indices(self) -> list[int]: - return sorted([len(self) - 1 - 2 * i for i in range(self._config.output_layer.prediction_heads)]) - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - if self._config.output_layer.tied_weight: - return { - WORD_EMBEDDINGS_WEIGHT: ( - self.embedding.word_embeddings_weight, - # TODO ====== Vision ====== - # (self.embedding_layer_index, *self.model_head_indices), - (0, *self.model_head_indices), - ) - } - elif self._config.output_layer.prediction_heads > 1: - return { - OUTPUT_WEIGHTS: ( - self.model_head.output_weights, - tuple(self.model_head_indices), - ) - } - else: - return {} + # TODO ====== Tied weights ====== + if self._config.tied_embedding_weight: + raise NotImplementedError() + return {} + # if self._config.output_layer.tied_weight: + # return { + # WORD_EMBEDDINGS_WEIGHT: ( + # self.embedding.word_embeddings_weight, + # # TODO ====== Vision ====== + # # 
(self.embedding_layer_index, *self.model_head_indices), + # (0, *self.model_head_indices), + # ) + # } + # elif self._config.output_layer.prediction_heads > 1: + # return { + # OUTPUT_WEIGHTS: ( + # self.model_head.output_weights, + # tuple(self.model_head_indices), + # ) + # } + # else: + # return {} + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return ( + self.embeddings_layer.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.output_layer.get_loss_definitions(count) + ) class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + # TODO: Can we drop class? + pass class GPTInferenceRunner(InferenceRunner): diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index f14f028e1..7bc215670 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -248,6 +248,7 @@ def test_lm_head( else: logit_weight = None + # TODO ====== Fix ====== for prediction_distance, layer_index in enumerate(model.model_head_indices): # Prepare the LM head head: LanguageModelHead = model[layer_index] diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 0dc3462eb..b086c291f 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -41,7 +41,7 @@ def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): # Create a fast-llm stage which allocates and initializes meta tensors correctly. stage = Stage( config=StageConfig(), - base_model=base_model, + layers=base_model, distributed_config=distributed.config, begin=0, end=1, From f31a313a4b6176826ab081ee9a998623922cb1e4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 3 Oct 2025 19:48:25 -0400 Subject: [PATCH 13/15] Fix merge --- Megatron-LM | 2 +- examples/mistral.yaml | 6 +- fast_llm/core/distributed.py | 54 ++++- fast_llm/engine/base_model/base_model.py | 44 +++- fast_llm/engine/checkpoint/huggingface.py | 12 +- fast_llm/engine/evaluation/config.py | 7 + .../engine/evaluation/lm_eval/evaluator.py | 2 + .../evaluation/lm_eval/fast_llm_wrapper.py | 226 ++++++++++++------ fast_llm/engine/inference/huggingface.py | 157 +++++++++++- fast_llm/engine/multi_stage/fsdp.py | 6 +- fast_llm/engine/multi_stage/multi_stage.py | 99 +++++--- fast_llm/engine/multi_stage/stage.py | 10 +- fast_llm/engine/multi_stage/stage_base.py | 32 ++- fast_llm/engine/schedule/runner.py | 7 +- fast_llm/layers/block/sequence.py | 44 ++-- fast_llm/layers/decoder/block.py | 4 + fast_llm/layers/language_model/config.py | 78 +++--- fast_llm/layers/language_model/head.py | 56 +++-- .../layers/language_model/language_model.py | 61 +++++ .../language_model/multi_token_prediction.py | 25 +- fast_llm/models/gpt/config.py | 49 ++-- fast_llm/models/gpt/conversion/apriel.py | 25 +- fast_llm/models/gpt/conversion/llama.py | 91 +++---- fast_llm/models/gpt/conversion/mistral.py | 4 +- fast_llm/models/gpt/conversion/mtp_llama.py | 70 ++++-- fast_llm/models/gpt/huggingface.py | 11 +- fast_llm/models/gpt/model.py | 130 +++------- fast_llm/models/gpt/trainer.py | 7 +- .../modeling_apriel_hybrid_ssm.py | 35 +++ tests/layers/test_lm_head.py | 127 ++++++---- tests/models/test_checkpoint.py | 2 +- tests/test_attention.py | 14 +- tests/test_config.py | 43 ++-- tests/test_multi_stage.py | 9 +- tests/utils/compare_tensor_logs.py | 4 +- tests/utils/distributed_configs.py | 6 +- tests/utils/model_configs.py | 27 ++- tests/utils/utils.py | 18 +- 38 files changed, 1037 insertions(+), 567 deletions(-) create mode 
100644 fast_llm/layers/language_model/language_model.py diff --git a/Megatron-LM b/Megatron-LM index 30e7aeccd..dee27459d 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 30e7aeccd87ec22e424f35c6e61f05ceb878a8df +Subproject commit dee27459d46fecc513be76732a0095bb38be32fb diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 88655954f..2e4a57de7 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -27,7 +27,7 @@ optimizer: beta_2: 0.95 model: base_model: - embeddings_layer: + embeddings: hidden_size: 4096 vocab_size: 32000 dropout: 0.0 @@ -54,11 +54,11 @@ model: epsilon: 1.0e-05 dropout: 0.0 num_blocks: 32 - output_layer: - tied_weight: false + head: normalization: type: rms_norm epsilon: 1.0e-05 + tied_embedding_weight: false multi_stage: zero_stage: 2 distributed: diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 86f8e7297..16b7c3921 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -107,6 +107,54 @@ def broadcast_scalar( return tensor.item() +def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None, src: int = 0) -> typing.Any: + """ + Broadcasts a Python object from src rank to all other ranks in the ProcessGroup. + Returns the object on all ranks. + """ + assert group is not None + + if group.rank() == src: + tensor = _object_to_tensor(input_object) + size = tensor.numel() + broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast_tensor.copy_(tensor) + broadcast_scalar(size, torch.int64, group, src) + broadcast(broadcast_tensor, src, group) + return input_object + else: + size = int(broadcast_scalar(None, torch.int64, group, src)) + output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast(output_tensor, src, group) + return _tensor_to_object(output_tensor) + + +def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup = None, src: int = 0) -> torch.Tensor: + """ + Broadcasts an optional tensor of size, shape, and dtype unknown in advance. + Returns the tensor on all ranks or None if no tensor was sent. + """ + assert group is not None + + if group.rank() == src: + has_tensor = tensor is not None + if has_tensor: + meta = (has_tensor, tensor.shape, tensor.dtype) + else: + meta = (has_tensor, None, None) + broadcast_object(meta, group, src) + if has_tensor: + broadcast(tensor.to(torch.cuda.current_device()), src, group) + return tensor + else: + has_tensor, shape, dtype = broadcast_object(None, group, src) + if not has_tensor: + return None + output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device()) + broadcast(output_tensor, src, group) + return output_tensor + + def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: assert group is not None work = group.send([tensor], dst, tag) @@ -186,7 +234,11 @@ def scatter( def _object_to_tensor(obj: typing.Any) -> torch.Tensor: f = io.BytesIO() pickle.Pickler(f).dump(obj) - return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. 
+ # See: https://github.com/pytorch/pytorch/issues/65696 + return torch.ByteTensor(byte_storage) def _tensor_to_object(tensor: torch.Tensor) -> typing.Any: diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index ce7002c54..5df59d4cd 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -73,6 +73,11 @@ def forward( ) -> torch.Tensor: pass + def unwrap(self) -> "Layer": + # Get the actual module contained in this layer, + # undoing any wrapping for the Fast-LLM engine (ex. `LayerWithNamespace`) + return self + class LayerWithNamespace(Layer): """ @@ -81,12 +86,13 @@ class LayerWithNamespace(Layer): TODO: Consider namespace for losses and metrics? """ - def __init__(self, layer: Layer, namespace: str): + def __init__(self, layer: Layer, namespace: str = None): super().__init__(layer._distributed_config) self._layer = layer self._namespace = namespace self.layer_count = self._layer.layer_count self.get_compute_usage = self._layer.get_compute_usage + self.module_name = self._layer.module_name def setup(self, distributed: Distributed) -> None: self._layer.setup(distributed) @@ -95,12 +101,21 @@ def setup(self, distributed: Distributed) -> None: def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: - return self._layer.forward(input_, kwargs[self._namespace], losses, metrics) + if self._namespace in kwargs: + kwargs = kwargs[self._namespace] + else: + # TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace. + # Using kwargs as-is since it's generally unused. + assert isinstance(input_, TensorMeta) + return self._layer.forward(input_, kwargs, losses, metrics) def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: assert self._namespace not in kwargs kwargs[self._namespace] = kwargs.copy() - return self._layer.preprocess(batch, kwargs[self._namespace]) + self._layer.preprocess(batch, kwargs[self._namespace]) + + def unwrap(self) -> "Layer": + return self._layer.unwrap() class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase): @@ -118,11 +133,11 @@ def __init__( @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: - # TODO ====== Remove (Move batch splitting elsewhere) ====== + # TODO Remove (Move batch splitting elsewhere) pass @abc.abstractmethod - def preprocess( + def preprocess_batch( self, batch: typing.Any, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, @@ -131,16 +146,19 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== + # TODO Move batch splitting elsewhere, align interface with LayerBase pass - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - # TODO ====== Tied weights ====== - # Return tuples of independently defined metas to tie together. - # For each tied weight, return the weight and the tuple of layers sharing it. - # The weight should be defined in the first layer in the set. - # Warning: This may return buffers instead of metas after stage setup. - # The name (dict key) is used to insert the weight in the kwargs of the forward pass. + def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]: + """ + Return tuples of independently defined metas to tie together. 
+ Metas should be compatible, i.e. have the same tensor dimensions. + Tied weights are named (dict keys) for convenience only. + Warning: Initialization and optimization properties are defined on the first appearance of the tied weight. + To prevent any confusion, the metas should be provided in the same order they appear in the model. + TODO: Improve? + Note: This may return buffers instead of metas after stage setup. + """ return {} def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index e5d14711d..afe381295 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -31,7 +31,7 @@ def export_config(cls, config: BaseModelConfig) -> dict: @classmethod @abc.abstractmethod - def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: + def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]: pass @@ -39,6 +39,10 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, architecture: typing.ClassVar[str] base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] + def __init__(self, model: "FastLLMModel"): + self._exported_config = self._export_config(model.config) + super().__init__(model) + @classmethod @abc.abstractmethod def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: @@ -126,10 +130,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: Assert.eq(config["architecture"], cls.architecture) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - return self.base_model_converter_class.get_converters(self._model.config.base_model) + def _create_weight_converters(self) -> list[WeightConverter]: + return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config) def _load_weights( self, config: CheckpointLoadConfig, device diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 4eb5d71df..4f035e174 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -98,6 +98,13 @@ class LmEvalEvaluatorConfig(EvaluatorConfig): " If not set, it is inferred from the Fast-LLM model config or tokenizer.", ) + communication_timeout_sec: float = Field( + default=600.0, + desc="Maximum wait time (in seconds) for tensor-parallel or data-parallel model " + "operations such as forward, generate, or gathering data. 
Needed because some "
+        "ranks may have no data or post-processing can be slow, exceeding the default 60s timeout.",
+    )
+
     def get_evaluator(
         self,
         name: str,
diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py
index 9040b11b4..14aed65c4 100644
--- a/fast_llm/engine/evaluation/lm_eval/evaluator.py
+++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -66,6 +66,8 @@ def setup(
             add_bos_token=self._config.add_bos_token,
             prefix_token_id=self._config.prefix_token_id,
             max_length=self._config.max_length,
+            batch_config=self._batch_config,
+            communication_timeout_sec=self._config.communication_timeout_sec,
         )
 
         self._is_setup = True
diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
index 3a606b41d..bc42515e7 100644
--- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
+++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
@@ -16,6 +16,7 @@
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results
 from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
+from fast_llm.engine.schedule.config import BatchConfig
 from fast_llm.layers.attention.rotary.config import NoRotaryConfig
 
 logger = logging.getLogger(__name__)
@@ -34,22 +35,29 @@ def __init__(
         add_bos_token: bool | None = False,
         prefix_token_id: int | None = None,
         max_length: int | None = None,
+        batch_config: BatchConfig | None = None,
+        communication_timeout_sec: float = 600.0,
     ):
         super().__init__()
 
         # === Distributed setup ===
         self._rank = 0  # For lm_eval: always run on main rank
-        self._world_size = 1
+        self._world_size = 1  # For lm_eval: always world size 1
+
+        self.communication_timeout_sec = communication_timeout_sec
+
         self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed
+        self._world_group = self._distributed.world_group
+
         if (
             self._distributed.config.sequence_data_rank == 0
             and self._distributed.config.pipeline_rank == 0
             and self._distributed.config.tensor_rank == 0
         ):
-            self._group = self._distributed.batch_data_group
+            self._leading_batch_data_group = self._distributed.batch_data_group
         else:
-            self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER
+            self._leading_batch_data_group = None
 
         # === Model & tokenizer setup ===
         self._model = model
@@ -76,7 +84,10 @@ def __init__(
         # === Batch configuration ===
         self._batch_schedule = 1
         self._batch_sizes = {}  # Not used dynamically by lm_eval
-        self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size
+
+        # NOTE: We cannot take the batch configuration from the inference runner, as it has a dummy batch config.
+        self._batch_size_per_gpu = batch_config.micro_batch_size if batch_config else 1
+
         self._batch_size = self._batch_size_per_gpu * self._distributed.config.batch_data_parallel
         self._max_batch_size = self._batch_size
@@ -171,11 +182,15 @@ def run(self, cli_args: list[str], completed_steps: int, run_index: int):
                 completed_steps,
             )
         else:
-            self.worker_model_invoke()
+            # On the rest of the batch data group leaders, we run the full generate/forward pass.
+            # On all other ranks, we only invoke worker_forward.
+            if self._leading_batch_data_group:
+                self.worker_model_invoke()
+            else:
+                self._model.worker_forward(communication_timeout_sec=self.communication_timeout_sec)
 
-        # TODO: do we need it here as self.stop_workers() and self.worker_model_invoke()
-        # already have barrier
-        safe_barrier(self._distributed.world_group, f"lm_eval Run end")
+        # Model forward workers end earlier, so sync here for all GPUs.
+        safe_barrier(self._world_group, f"lm_eval Run end")
 
     def _model_invoke(
         self,
@@ -185,39 +200,44 @@ def _model_invoke(
         max_length,
         stop,
         generate: bool,
-        continue_generate: bool,
+        continue_work: bool,
         **generation_kwargs,
     ):
         # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple.
         # Messages could include types like logits, generate, finished.
 
-        # Group is always None if world size is 1
-        if self._group is None:
-            # Must not be called with continue_generate false on one process
-            assert continue_generate
+        # Call directly if on one GPU or if the data group size is 1.
+        if self._world_group is None or self._leading_batch_data_group is None:
+            # Must not be called with continue_work false on one GPU.
+            assert self._world_group or continue_work
+            # Still called when continue_work is false, to stop the model forward workers.
             return self._model_invoke_inner(
-                input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
+                input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
             )
 
-        world_size = self._group.size()
+        assert self._world_group.rank() == 0
 
-        assert self._group.rank() == 0
+        batch_data_parallel_size = self._leading_batch_data_group.size()
 
-        if continue_generate:
+        if continue_work:
             assert input_ids is not None
             if generate:
                 assert max_length is not None and stop is not None
 
-            # always divide by world_size, if not full batch, some ranks will get less work or not at all
-            assert self._batch_size % world_size == 0
-            step = self._batch_size // world_size
+            # Always divide by batch_data_parallel_size; if the batch is not full, some ranks get less work or none at all.
+            assert self._batch_size % batch_data_parallel_size == 0
+            step = self._batch_size // batch_data_parallel_size
 
-            input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)]
+            # Data is sent to every rank, and micro-batches are repeated for the same batch_data_parallel rank.
+            input_ids = [input_ids[i * step : (i + 1) * step] for i in range(batch_data_parallel_size)]
             attention_mask = [
-                attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None
-                for i in range(world_size)
+                (attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None)
+                for i in range(batch_data_parallel_size)
+            ]
+            labels = [
+                (labels[i * step : (i + 1) * step] if labels is not None else None)
+                for i in range(batch_data_parallel_size)
             ]
-            labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)]
 
             scatter_list = [
                 [
@@ -227,36 +247,46 @@ def _model_invoke(
                     max_length,
                     stop,
                     generate,
-                    continue_generate,
+                    continue_work,
                     generation_kwargs,
                 ]
-                for i in range(world_size)
+                for i in range(batch_data_parallel_size)
             ]
         else:
-            scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)]
+            scatter_list = [[[], None, None, None, None, None, False, {}] for _ in range(batch_data_parallel_size)]
 
-        input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = (
+        # Some tasks may post-process too slowly, so waiting for the next batch or
+        # the end of work can exceed the standard 60s timeout.
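# Editor's illustrative sketch (not part of the patch): how the coordinator above shards a
# batch across the batch-data-parallel group leaders. With _batch_size = 8 and 4 leaders,
# step = 2; a partial batch of 5 samples leaves the last shard empty, which is why empty
# shards are tolerated and later filtered out of the gather list. `shard_batch` is a
# hypothetical helper, not a Fast-LLM function.
def shard_batch(samples: list, batch_size: int, batch_data_parallel_size: int) -> list[list]:
    assert batch_size % batch_data_parallel_size == 0
    step = batch_size // batch_data_parallel_size
    # Leaders beyond the available data receive empty slices ("less work or none at all").
    return [samples[i * step : (i + 1) * step] for i in range(batch_data_parallel_size)]


# Example: 5 samples with batch size 8 and 4 leaders -> [['s0', 's1'], ['s2', 's3'], ['s4'], []]
shards = shard_batch(["s0", "s1", "s2", "s3", "s4"], batch_size=8, batch_data_parallel_size=4)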
+        safe_barrier(self._leading_batch_data_group, "model_invoke_wait", timeout=self.communication_timeout_sec)
+        input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = (
             scatter_object(
                 scatter_list,
-                group=self._group,
+                group=self._leading_batch_data_group,
             )
         )
 
-        if not continue_generate:
+        # Always call the inner function to propagate the stop signal to TP workers if continue_work is False.
+        result = self._model_invoke_inner(
+            input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
+        )
+
+        if not continue_work:
             return None
 
         assert len(input_ids) > 0
 
-        result = self._model_invoke_inner(
-            input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
+        # At the end, some data-parallel ranks may have no data, so the wait can
+        # exceed the standard 60s timeout.
+        safe_barrier(
+            self._leading_batch_data_group, "model_invoke_gather_wait", timeout=self.communication_timeout_sec
         )
+        gather_list = gather_object(result, group=self._leading_batch_data_group)
 
-        gather_list = gather_object(result, group=self._group)
-        # Clean gather list from empty shards
+        # Clean the gather list from empty shards (from incomplete batches).
         gather_list = [el for el in gather_list if len(el) > 0]
 
         # If it was model generate tensors could be of different length
-        # so we aggregate results to list instead of a tensor
+        # so we aggregate the results into a list instead of a tensor.
         if generate:
             result = sum((el.tolist() for el in gather_list), [])
         else:
@@ -266,57 +296,109 @@ def _model_invoke(
 
         return result
 
     def worker_model_invoke(self):
-        assert self._group is not None
-        # if isinstance(self.group, dist.ProcessGroup):
-        if not isinstance(self._group, int):
-            # groups is None for world_size 1
-            assert self._group.rank() != 0
-            # on worker ranks the function need to wait to be called multiple times
-            while True:
-                input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = (
-                    scatter_object(
-                        None,
-                        group=self._group,
-                    )
+        # The group is None for world_size 1, and this function must not be called for world_size 1.
+        assert self._world_group
+        assert self._leading_batch_data_group
+        # The function must not be called on the main rank.
+        assert self._world_group.rank() != 0
+
+        device = torch.cuda.current_device()
+
+        # On worker ranks, the function needs to wait to be called multiple times.
+        while True:
+            # Some tasks may post-process too slowly, so waiting for the next batch or
+            # the end of work can exceed the standard 60s timeout.
+            safe_barrier(self._leading_batch_data_group, "model_invoke_wait", timeout=self.communication_timeout_sec)
+            input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = (
+                scatter_object(
+                    None,
+                    group=self._leading_batch_data_group,
                 )
+            )
 
+            # NOTE: scatter_object keeps tensors on the same device as the source,
+            # so they must be moved to the current device.
+            # TODO: With scatter_object, tensors are copied GPU → CPU → GPU, then scattered,
+            # and finally copied again to the correct GPU here. We already have a scatter
+            # primitive for tensors of known size and type; we need to extend it to
+            # handle optional tensors of unknown type and size directly (src GPU → dst GPU);
+            # and use it for scattering tensors like input_ids, attention_mask, labels.
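# Editor's illustrative sketch (not part of the patch): the TODO above asks for a GPU-direct
# primitive for optional tensors of unknown shape and dtype. This patch already adds
# broadcast_object and broadcast_optional to fast_llm/core/distributed.py; an assumed usage
# pattern, with rank 0 of the group as the source, looks like this. `share_inputs` is a
# hypothetical helper, not part of the patch.
import torch

from fast_llm.core.distributed import broadcast_object, broadcast_optional


def share_inputs(input_ids: torch.Tensor | None, generate: bool, group) -> tuple[torch.Tensor | None, bool]:
    if group is None:
        # Single process: nothing to broadcast.
        return input_ids, generate
    if group.rank() == 0:
        # The source rank passes the (possibly None) tensor and any picklable metadata.
        broadcast_object(generate, group, 0)
        return broadcast_optional(input_ids, group, 0), generate
    # Other ranks pass None and receive the metadata and the tensor (allocated on the current device).
    generate = broadcast_object(None, group, 0)
    return broadcast_optional(None, group, 0), generate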
+            if isinstance(input_ids, torch.Tensor):
+                input_ids = input_ids.to(device)
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = attention_mask.to(device)
+            if isinstance(labels, torch.Tensor):
+                labels = labels.to(device)
+
+            if continue_work:
+                logger.info(f"worker_model_invoke: input_ids device {input_ids.device}, shape {input_ids.shape}")
+
+            # Always call the inner function to propagate the stop signal to TP workers if continue_work is False.
+            result = self._model_invoke_inner(
+                input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
+            )
 
-                # Stop signal was send, end waiting/processing loop
-                if not continue_generate:
-                    break
+            # Stop signal was sent, end the waiting/processing loop.
+            if not continue_work:
+                break
 
-                # if some data was received, work, otherwise return empty tensor
-                if len(input_ids) > 0:
-                    result = self._model_invoke_inner(
-                        input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
-                    )
-                else:
-                    result = input_ids
+            # If some data was received, return the processed results; otherwise return the empty tensor.
+            if len(input_ids) == 0:
+                result = input_ids
 
-                gather_object(result, group=self._group)
-        else:
-            # TODO: implement distributed model support
-            assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER
-        safe_barrier(self._distributed.world_group, "lm_eval_end")
+            # At the end, some data-parallel ranks may have no data, so the wait can
+            # exceed the standard 60s timeout.
+            safe_barrier(
+                self._leading_batch_data_group, "model_invoke_gather_wait", timeout=self.communication_timeout_sec
+            )
+            gather_object(result, group=self._leading_batch_data_group)
+
+        safe_barrier(self._leading_batch_data_group, "lm_eval_end")
 
     def stop_workers(self):
         # Group is always None if world size is 1
-        if self._group is None:
+        if self._world_group is None:
             return
-        self._model_invoke(None, None, None, None, None, None, continue_generate=False)
-        safe_barrier(self._distributed.world_group, "lm_eval_end")
+
+        self._model_invoke([], None, None, None, None, None, continue_work=False)
+
+        # worker_model_invoke is only called if the data group size > 1, and it needs to be synced here.
+        if self._leading_batch_data_group:
+            safe_barrier(self._leading_batch_data_group, "lm_eval_end")
 
     def _model_invoke_inner(
-        self, input_ids, attention_mask, labels, max_length, stop, generate: bool, **generation_kwargs
+        self,
+        input_ids,
+        attention_mask,
+        labels,
+        max_length,
+        stop,
+        generate: bool,
+        continue_work: bool,
+        **generation_kwargs,
     ):
+        # If stopping, stop the model forward workers and return.
+        if not continue_work:
+            # continue_work=False should not occur on a single GPU.
+            # This function must be called on batch data parallel group leaders.
+            # If there is only one data rank, the leader is global rank 0 and the data group will be None.
+ assert self._world_group is not None + assert self._world_group.rank() == 0 or self._leading_batch_data_group + self._model.stop_workers() + return None + + # If input_ids is empty, there is no work to process - return early + if len(input_ids) == 0: + # Receiving no work can only happen on non-zero ranks + assert self._world_group is not None and self._world_group.rank() != 0 + return None + if generate: return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) else: return self._model_call_inner(input_ids, attention_mask, labels) def _model_call(self, input_ids, attention_mask=None, labels=None): - return self._model_invoke( - input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True - ) + return self._model_invoke(input_ids, attention_mask, labels, None, None, generate=False, continue_work=True) def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): return self._model_invoke( @@ -326,7 +408,7 @@ def _model_generate(self, input_ids, attention_mask, max_length, stop, **generat max_length, stop, generate=True, - continue_generate=True, + continue_work=True, **generation_kwargs, ) @@ -370,6 +452,8 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): output_attentions=False, output_hidden_states=False, return_dict=True, + coordinator_forward=True, + communication_timeout_sec=self.communication_timeout_sec, ).logits.cpu() def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): @@ -398,6 +482,8 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g stopping_criteria=stopping_criteria, pad_token_id=self._tokenizer.pad_token_id, use_cache=False, + coordinator_forward=True, + communication_timeout_sec=self.communication_timeout_sec, **generation_kwargs, ) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 54a82492b..b634f8a4d 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -7,7 +7,9 @@ import transformers.generation.utils import transformers.modeling_outputs +from fast_llm.core.distributed import broadcast_object, broadcast_optional, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode @@ -57,12 +59,12 @@ def __init__( # or set existing model which also must be setup, so, do not accept not setup model assert fast_llm_model.is_setup - # We only support data parallel for now - Assert.eq(fast_llm_model.distributed.config.model_parallel, 1) - Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1) - self._inference_runner.setup() + # We only support data parallel and tensor parallel for now + Assert.eq(fast_llm_model.distributed.config.pipeline_parallel, 1) + Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1) + # Transformers needs to be able to inspect the base model. 
self.fast_llm_base_model = fast_llm_model.base_model @@ -112,7 +114,7 @@ def _init_weights(self, module) -> None: class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin): - def forward( + def inner_forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, @@ -127,3 +129,148 @@ def forward( ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: # Meant to be overridden in derived classes raise NotImplementedError() + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values=None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + coordinator_forward: bool = False, + communication_timeout_sec: float = 600.0, + continue_work: bool = True, + ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast | None: + """ + Forward pass compatible with HuggingFace forward. + + Additional arguments: + coordinator_forward (bool): + If True, only the TP group coordinator (rank 0) should call forward; + other ranks must call worker_forward. + If False, all TP group ranks call forward independently and return logits. + communication_timeout_sec (float): Maximum time (in seconds) to wait for the start of + forward or for a stop signal to worker ranks before timing out in worker_forward. + Must match the value passed to worker_forward. + continue_work (bool): Whether to continue processing in a TP group. + Only applies for coordinator_forward=True. + + Notes: + - In coordinator_forward=True mode, forward on rank 0 distributes data to other ranks. + - After processing, the coordinator (rank 0) must call `stop_workers()` before continuing, + to unblock worker_forward on other ranks. + - This mode augments HuggingFace generate with tensor-parallel capability. + """ + distributed: Distributed = self._inference_runner._fast_llm_model.distributed + + if coordinator_forward and distributed.world_group and distributed.tensor_group: + assert distributed.tensor_group.rank() == 0 + assert past_key_values is None and not use_cache + + # Some tasks may post-process too slowly, so waiting for the next batch or + # the end of work can exceed the standard 60s timeout. + safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec) + + broadcast_optional(input_ids, distributed.tensor_group, 0) + broadcast_optional(attention_mask, distributed.tensor_group, 0) + broadcast_optional(position_ids, distributed.tensor_group, 0) + broadcast_optional(inputs_embeds, distributed.tensor_group, 0) + broadcast_optional(labels, distributed.tensor_group, 0) + + broadcast_object( + (past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work), + distributed.tensor_group, + 0, + ) + + if not coordinator_forward or continue_work: + return self.inner_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + ) + + return None + + def worker_forward( + self, + communication_timeout_sec: float = 600.0, + ): + """ + Run the forward loop on worker ranks in coordinated mode. 
+ + This function must be called on all worker ranks (i.e., all ranks except the + coordinator/leading data-parallel rank). In coordinated mode, the coordinator + rank calls `forward`, which distributes inputs to workers. Each worker then + receives its inputs and runs a forward pass. + + Workers stay in this loop until a stop signal is broadcast, which happens when + the coordinator rank calls `stop_workers`. + + Args: + communication_timeout_sec (float): Maximum time (in seconds) to wait for the + start of a forward call or for a stop signal from the coordinator before + timing out. Must match the value passed to `forward`. + + Notes: + - Coordinator rank: calls `forward` in coordinated mode and later + `stop_workers` to unblock workers. + - Worker ranks: call `worker_forward` once and remain inside the loop, + executing forward passes with broadcasted inputs until a stop signal + is received. + """ + distributed: Distributed = self._inference_runner._fast_llm_model.distributed + assert distributed.world_group and distributed.tensor_group and distributed.tensor_group.rank() != 0 + + while True: + # Some tasks may post-process too slowly, so waiting for the next batch or + # the end of work can exceed the standard 60s timeout. + safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec) + + input_ids = broadcast_optional(None, distributed.tensor_group, 0) + attention_mask = broadcast_optional(None, distributed.tensor_group, 0) + position_ids = broadcast_optional(None, distributed.tensor_group, 0) + inputs_embeds = broadcast_optional(None, distributed.tensor_group, 0) + labels = broadcast_optional(None, distributed.tensor_group, 0) + + past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work = ( + broadcast_object(None, distributed.tensor_group, 0) + ) + + if not continue_work: + break + + self.inner_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + ) + + safe_barrier(distributed.world_group, "forward_work_end") + + def stop_workers(self): + distributed: Distributed = self._inference_runner._fast_llm_model.distributed + # On single gpu or no tp, no worker_forward to stop + if distributed.world_group is None or distributed.tensor_group is None: + return + self.forward(coordinator_forward=True, continue_work=False) + safe_barrier(distributed.world_group, "forward_work_end") diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 868cc2db4..827079f6e 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -56,11 +56,11 @@ def __init__( # The index range of the parameters in the buffer. 
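# Editor's illustrative sketch (not part of the patch): the coordinated tensor-parallel
# forward protocol documented above for HuggingfaceBaseModelForCausalLM. Rank 0 of the
# tensor group drives forward/generate and must call stop_workers() when done; all other
# tensor-parallel ranks block in worker_forward() until the stop signal arrives.
# `model` is assumed to be a HuggingfaceBaseModelForCausalLM and `batches` an iterable of
# input_ids tensors; both are placeholders.
def run_coordinated_forward(model, batches, timeout: float = 600.0):
    distributed = model._inference_runner._fast_llm_model.distributed
    if distributed.tensor_group is None or distributed.tensor_group.rank() == 0:
        # Coordinator: run as many forward calls as needed, then release the workers.
        for input_ids in batches:
            output = model.forward(
                input_ids=input_ids,
                return_dict=True,
                coordinator_forward=True,
                communication_timeout_sec=timeout,
            )
            logits = output.logits
        model.stop_workers()
    else:
        # Workers: loop inside worker_forward until the coordinator calls stop_workers().
        model.worker_forward(communication_timeout_sec=timeout)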
self._parameter_begins_in_buffer = { parameter_meta.tensor_name: offset - for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1]) + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1], strict=True) } self._parameter_ends_in_buffer = { parameter_meta.tensor_name: offset - for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:]) + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:], strict=True) } # Shard properties @@ -377,7 +377,7 @@ def reduce_gradients( assert self._mode.support_backward if not self._requires_grad: return - for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values()): + for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values(), strict=True): if buffer.param_grad_is_zero: # noqa assert allow_no_grad or meta.allow_no_grad, meta triton_fill(buffer.grad_buffer, 0) # noqa diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 77dc4e7dd..f45f93862 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -3,7 +3,6 @@ import typing import warnings -import numpy as np import torch from torch._C._distributed_c10d import ProcessGroup @@ -62,42 +61,51 @@ def __init__( self._num_stages, self._config.distributed.pipeline_parallel * self._config.multi_stage.stages_per_pipeline_stage, ) + # Keep track of which stage each parameter belongs to. + self._parameter_stages: dict[str, int] = {} + for stage_index in range(self._num_stages): + for layer in self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]]: + for meta in layer.parameters(): + assert meta.tensor_name not in self._parameter_stages + self._parameter_stages[meta.tensor_name] = stage_index + + # Determine which stages belong to this pipeline rank. + self._stage_pipeline_ranks = { + stage_index: (stage_index // self._config.multi_stage.stages_per_pipeline_stage) + % self._config.distributed.pipeline_parallel + for stage_index in (range(self._num_stages)) + } + + # Set up tied weights. + self._tied_parameters = self._get_tied_parameters() + self._tied_parameter_duplicates = [{} for _ in range(self._num_stages)] + for tied_parameter in self._tied_parameters.values(): + for meta in tied_parameter.metas[1:]: + self._tied_parameter_duplicates[self._parameter_stages[meta.tensor_name]][ + meta.tensor_name + ] = tied_parameter # Create the stages. self._stages = [ Stage( config=self._config.multi_stage, - layers=self._layers[stage_splits[i] : stage_splits[i + 1]], + layers=self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]], distributed_config=self._config.distributed, - index=i, + index=stage_index, + tied_parameter_duplicates=tied_parameter_duplicates_.keys(), ) - for i in (range(self._num_stages)) + for stage_index, tied_parameter_duplicates_ in enumerate(self._tied_parameter_duplicates) ] if self._verbose: log_main_rank(lambda: f" Total parameters: {sum(stage_.parameter_count for stage_ in self._stages):,} ") - # Keep track of which stage each parameter belongs to. - self._parameter_stages: dict[str, int] = {} - for stage_index, stage in enumerate(self._stages): - for parameter_name in stage.parameter_names: - assert parameter_name not in self._parameter_stages - self._parameter_stages[parameter_name] = stage_index - - # Determine which stages belong to this pipeline rank. 
- self._stage_pipeline_ranks = { - stage_index: (stage_index // self._config.multi_stage.stages_per_pipeline_stage) - % self._config.distributed.pipeline_parallel - for stage_index in (range(self._num_stages)) - } self._stages_owned = { stage_index: self._stages[stage_index] for stage_index, stage_rank in self._stage_pipeline_ranks.items() if stage_rank == self._config.distributed.pipeline_rank } - # Set up tied weights. - self._tied_parameters = self._get_tied_parameters(stage_splits[1:]) self._tied_weight_main_stages_on_device = { stage_index: self._stages[stage_index] for stage_index in sorted( @@ -318,6 +326,16 @@ def _setup_stages(self) -> None: if self._mode.support_forward and weight_buffer_index is not None else [] ) + tied_weight_duplicate_buffers = ( + { + parameter_name: self._stages[tied_parameter.main_stage].get_parameter_buffer( + tied_parameter.metas[0].tensor_name + ) + for parameter_name, tied_parameter in self._tied_parameter_duplicates[stage_index].items() + } + if self._mode.support_forward and stage_index in self._stages_on_device + else None + ) stage.setup( distributed=self._distributed, weight_shards=stage_weight_shards, @@ -326,6 +344,7 @@ def _setup_stages(self) -> None: grad_buffers=stage_grad_buffers, mode=self._mode if stage_index in self._stages_on_device else StageMode.off_device, is_tied_weight_copy=stage_index in self._stages_on_device and stage_index not in self._stages_owned, + tied_parameter_duplicate_buffers=tied_weight_duplicate_buffers, weight_buffer_shared_with=weight_buffer_shared_with, ) @@ -533,17 +552,43 @@ def _get_buffer_placement(self, num_shared_buffers: int | None) -> tuple[list[se } return buffer_contents, buffer_indices - def _get_tied_parameters(self, stage_ends) -> dict[str, "TiedParameter"]: + def _get_tied_parameters(self) -> dict[str, "TiedParameter"]: tied_parameters = {} - for name, (meta, layer_indexes) in self._base_model.get_tied_weights().items(): - Assert.eq(list(layer_indexes), sorted(layer_indexes)) - Assert.incl(meta, list(self._base_model[layer_indexes[0]].parameters())) - stage_indexes = sorted({np.searchsorted(stage_ends, i, side="right").item() for i in layer_indexes}) + for name, metas in self._base_model.get_tied_parameters().items(): + if len(metas) <= 1: + continue + stage_indexes = [self._parameter_stages[meta.tensor_name] for meta in metas] + # TODO: Ambiguous if multiple instances are on the same stage? + Assert.eq( + sorted(stage_indexes), + stage_indexes, + msg="Tied parameters should be provided in the order they appear in the model.", + ) + for meta in metas[1:]: + # TODO: Improve. Compare initializations? (Not currently possible) + if ( + len(meta.dims) != len(metas[0].dims) + or any(dim != dim_ for dim, dim_ in zip(meta.dims, metas[0].dims, strict=True)) + or meta.sequence_tensor_parallel != metas[0].sequence_tensor_parallel + ): + raise ValueError( + f"Tied parameter group `{name}` has incompatible tied parameters {metas[0]} and {meta}." + ) + if ( + meta.requires_grad != metas[0].requires_grad + or meta.lr_scale != metas[0].lr_scale + or meta.param_weight_decay != metas[0].param_weight_decay + ): + logger.warning( + f"Tied parameters `{metas[0]}` and `{meta}` in tied parameter group `{name}` have different optimization parameters." + f" Only those of `{metas[0].tensor_name}` will be used." 
+ ) + all_ranks = {self._stage_pipeline_ranks[stage_index] for stage_index in stage_indexes} tied_parameters[name] = TiedParameter( name=name, - meta=meta, + metas=tuple(metas), all_ranks=all_ranks, on_device=self._config.distributed.pipeline_rank in all_ranks, main_stage=stage_indexes[0], @@ -555,11 +600,11 @@ def _get_tied_parameters(self, stage_ends) -> dict[str, "TiedParameter"]: class TiedParameter: name: str # Parameter definition. - meta: ParameterMeta + metas: tuple[ParameterMeta, ...] # Whether the local rank is involved at all. on_device: bool # Process group for reduction. - group: ProcessGroup | None = dataclasses.field(init=False) + group: ProcessGroup | None = dataclasses.field(repr=False, init=False) all_ranks: set[int] # The index of the main stage. main_stage: int diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index bb3133256..9f5543590 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -47,6 +47,7 @@ def setup( # noqa grad_buffers: list[torch.Tensor | None] | None = None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, weight_buffer_shared_with: collections.abc.Sequence["Stage"] = (), ) -> None: super().setup( @@ -56,6 +57,7 @@ def setup( # noqa weight_buffers=weight_buffers, grad_buffers=grad_buffers, mode=mode, + tied_parameter_duplicate_buffers=tied_parameter_duplicate_buffers, ) self._is_tied_weight_copy = is_tied_weight_copy if self._mode.support_forward: @@ -68,6 +70,9 @@ def setup( # noqa self._accumulators = [] with torch.enable_grad(): for meta in self._parameter_metas: + if meta.tensor_name in self._tied_parameter_duplicates: + # Already handled in the main stage. + continue buffer = self.get_parameter_buffer(meta.tensor_name) if not buffer.requires_grad: continue @@ -140,7 +145,6 @@ def forward( # TODO: Handle variable shape. 
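# Editor's illustrative sketch (not part of the patch): what the new get_tied_parameters()
# contract validated above expects from a base model. The metas must have matching tensor
# dimensions and be listed in the order they appear in the model; the first meta owns the
# initialization and optimization settings. The class name, config flag, attribute names and
# import paths below (MyModel, tied_embedding_weight, word_embeddings_weight) are assumed
# for illustration only.
from fast_llm.engine.base_model.base_model import BaseModel
from fast_llm.tensor import ParameterMeta


class MyModel(BaseModel):
    def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]:
        if self._config.head.tied_embedding_weight:
            # Tie the input embedding matrix to the language-model head projection.
            return {"word_embeddings": [self.embeddings.word_embeddings_weight, self.head.output_weights]}
        return {}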
output_global = output - # TODO ====== Use ====== kwargs["hidden_states"][self._layers[i].module_name] = { "layer_type": type(layer).__name__, "tensor": output_global, @@ -227,7 +231,7 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] ): check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layers[i].module_name} fw") if self._config.debug_layer_outputs: - name = f"layer {self._layers[i].module_name} fw" + name = f"{self._layers[i].module_name} fw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -260,7 +264,7 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any ) ) if self._config.debug_layer_gradients: - name = f"layer {self._layers[i].module_name} bw" + name = f"{self._layers[i].module_name} bw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 4778780ee..96d80ce06 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -31,6 +31,7 @@ def __init__( layers: list[Layer], distributed_config: DistributedConfig, index: int, + tied_parameter_duplicates: typing.Iterable[str] = (), ): super().__init__(config) self._distributed_config = distributed_config.validate() @@ -39,9 +40,10 @@ def __init__( self._is_setup = False self._index = index self._layers = layers + self._tied_parameter_duplicates = set(tied_parameter_duplicates) - parameter_metas, frozen_metas = self._get_parameter_metas() - self._parameter_metas = parameter_metas + frozen_metas + parameter_metas, frozen_metas, duplicate_metas = self._get_parameter_metas() + self._parameter_metas = parameter_metas + frozen_metas + duplicate_metas self._fsdps = [] if parameter_metas: self._fsdps.append( @@ -106,6 +108,7 @@ def setup( weight_buffers: list[torch.Tensor | None] | None, grad_buffers: list[torch.Tensor | None] | None, mode: StageMode = StageMode.training, + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None, ) -> None: assert not self._is_setup distributed.check_config(self._distributed_config) @@ -142,7 +145,11 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) + if meta.tensor_name in self._tied_parameter_duplicates: + assert tied_parameter_duplicate_buffers is not None + module._parameters[key] = tied_parameter_duplicate_buffers.pop(meta.tensor_name) + else: + module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 i = 0 @@ -150,6 +157,7 @@ def _replace(module: torch.nn.Module): layer.apply(_replace) Assert.eq(i, len(self._parameter_metas)) + assert not tied_parameter_duplicate_buffers, tied_parameter_duplicate_buffers.keys() def initialize_weights(self) -> None: # TODO: Avoid all the _on_device checks @@ -172,6 +180,9 @@ def initialize_weights(self) -> None: ] for meta in metas: + if meta.tensor_name in self._tied_parameter_duplicates: + # Initialization is not managed by this stage. 
+ continue fsdp = self._fsdps[fsdp_index := self._fsdp_index[meta.tensor_name]] parameter = weight_shards_split[fsdp_index][meta.tensor_name] # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) @@ -309,24 +320,31 @@ def _export_shard( for fsdp, shard in zip(self._fsdps, shards, strict=True): yield from fsdp.export_shard(shard, data_type) - def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: + def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, # then separate the parameters with and without weight decay, # and squeeze the non-tensor parallel and sequence parallel ones in the middle. # This allows running the optimizer, grad norm and sequence_parallel reduction on contiguous buffers. parameter_metas: list[ParameterMeta] = [] frozen_metas: list[ParameterMeta] = [] + duplicate_metas: list[ParameterMeta] = [] meta: ParameterMeta for layer in self._layers: - for name, meta in layer.named_parameters(): + for meta in layer.parameters(): Assert.custom(isinstance, meta, ParameterMeta) Assert.eq(meta.dtype, self._distributed_config.optimization_dtype.torch) - if meta.requires_grad: + if meta.tensor_name in self._tied_parameter_duplicates: + duplicate_metas.append(meta) + elif meta.requires_grad: parameter_metas.append(meta) else: frozen_metas.append(meta) - return self._reorder_parameter_metas(parameter_metas), self._reorder_parameter_metas(frozen_metas) + return ( + self._reorder_parameter_metas(parameter_metas), + self._reorder_parameter_metas(frozen_metas), + self._reorder_parameter_metas(duplicate_metas), + ) @classmethod def _reorder_parameter_metas(cls, parameter_metas): diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 58449f207..133b3206b 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -324,7 +324,7 @@ def _preprocess_data( for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: - micro_batch_data = self._multi_stage.base_model.preprocess( + micro_batch_data = self._multi_stage.base_model.preprocess_batch( micro_batch_data, context.schedule.preprocessed_meta, phase=context.phase, @@ -339,11 +339,6 @@ def _preprocess_data( num_micro_batches=batch_config.sequential_micro_batches, micro_batch_splits=batch_config.micro_batch_splits, ) - for name, tied_parameter in self._tied_parameters.items(): - if tied_parameter.on_device: - kwargs[name] = self._stages[tied_parameter.main_stage].get_parameter_buffer( - tied_parameter.meta.tensor_name - ) data_index = context.schedule.get_data_index(micro_batch, micro_batch_split) if self._stages_owned[0]: context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_ diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index 57621a848..530df950e 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -1,4 +1,6 @@ import collections +import functools +import typing import torch.nn @@ -30,8 +32,9 @@ def __init__( lr_scale=lr_scale, peft=peft, ) + self.extend( - layers := [ + [ self._config.block.get_layer( distributed_config, hidden_dim, @@ -41,18 +44,24 @@ def __init__( for _ in range(self._config.num_blocks) ] ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # This needs to be in a property because `module_name` is set 
after `__init__`. # Wrap all blocks in a namespace using the unique module name of the first one. - namespace = layers[0].module_name if self._config.num_blocks > 0 else "" - # Note: Pytorch won't redundantly register modules because it doesn't look into lists. - self._layers_with_namespace = [ - LayerWithNamespace(sublayer, namespace) for layer in layers for sublayer in layer.get_layers() - ] + namespace = self[0].module_name if self._config.num_blocks > 0 else "" + return [LayerWithNamespace(sublayer, namespace) for layer in self for sublayer in layer.get_layers()] def get_layers(self) -> list["Layer"]: return self._layers_with_namespace + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + self._layers_with_namespace[0].preprocess(batch, kwargs) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self[0].get_loss_definitions(count=count * self.num_blocks) if self._config.num_blocks > 0 else [] + return ( + self[0].get_loss_definitions(count=count * self._config.num_blocks) if self._config.num_blocks > 0 else [] + ) class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): @@ -75,7 +84,7 @@ def __init__( peft=peft, ) self.extend( - layers := [ + [ self._config.blocks[name].get_layer( distributed_config, hidden_dim, @@ -85,24 +94,31 @@ def __init__( for name in self._config.expanded_pattern ] ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # This needs to be in a property because `module_name` is set after `__init__`. # Wrap each set of blocks with identical config in a namespace # using the unique module name of the first such block. - # Note: Pytorch won't redundantly register modules because it doesn't look into lists. - self._layers_with_namespace = [ - LayerWithNamespace(sublayer, layers[self._config.preprocessing_layers[name]].module_name) - for name, layer in zip(self._config.expanded_pattern, layers) + return [ + LayerWithNamespace(sublayer, self[self._config.preprocessing_layers[name]].module_name) + for name, layer in zip(self._config.expanded_pattern, self) for sublayer in layer.get_layers() ] - def get_layers(self) -> list["Layer"]: + def get_layers(self) -> list[Layer]: return self._layers_with_namespace + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + for _, index in self._config.preprocessing_layers.items(): + self._layers_with_namespace[index].preprocess(batch, kwargs) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: # TODO: Prevent name conflicts. 
return sum( ( self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_) - for name, count_ in collections.Counter(self.expanded_pattern).items() + for name, count_ in collections.Counter(self._config.expanded_pattern).items() ), [], ) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 08dd5a815..8b19db66a 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -175,5 +175,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) ) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self.mixer.preprocess(batch, kwargs) + self.mlp.preprocess(batch, kwargs) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 2c5032967..d2fbc4909 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,7 +1,8 @@ +import abc import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import LossDef, ModuleConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -10,30 +11,18 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.vision.config import VisionEncoderConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding - from fast_llm.layers.language_model.head import LanguageModelHead + from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -# class LanguageModelLossNames: -# language_model_loss = "language_model_loss" -# z_loss = "z_loss" -# dpo_loss = "dpo_loss" -# distil_lm_loss = "distillation_language_model_loss" # the next token perdiciton of combined distillation loss -# distillation_loss = "distillation_loss" - - class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" position_ids = "position_ids" - embedding_map = "embedding_map" # TODO: These are generic labels = "labels" - tokens = "tokens" phase = "phase" chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" @@ -48,10 +37,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) - vision_encoder: VisionEncoderConfig = Field( - desc="Configuration for the vision encoder that transforms images into embeddings.", - hint=FieldHint.optional, - ) position_embeddings: OptionalParameterConfig = Field( desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, @@ -122,7 +107,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ): + ) -> "LanguageModelHeadBase": return self.layer_class( self, distributed_config, @@ -132,8 +117,13 @@ def get_layer( peft=peft, ) + @property + 
@abc.abstractmethod + def max_prediction_distance(self) -> int: + pass + -@config_class(dynamic_type={LanguageModelHeadBaseConfig: "default"}) +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"}) class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): _abstract = False normalization: NormalizationConfig = Field( @@ -193,9 +183,9 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - enable_dpo: bool | None = Field( - default=False, - desc="Whether to enable DPO loss", + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) dpo_beta: float | None = Field( @@ -203,11 +193,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Beta value for DPO loss.", hint=FieldHint.feature, ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) distillation_model: str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -253,6 +238,15 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() + assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + + @property + def max_prediction_distance(self) -> int: + return 1 + + @property + def enable_dpo(self) -> bool: + return self.dpo_reference_model is not None @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) @@ -260,7 +254,6 @@ class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): _abstract = False # Needs to be `DecoderBlockConfig` for the `return_input` interface. # TODO: Make a generic wrapper for returning input instead? - # TODO ====== Tied weight ====== block: DecoderBlockConfig = Field( desc="Configuration for the decoder block before each head.", hint=FieldHint.architecture, @@ -276,7 +269,6 @@ class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # TODO ====== Adjust ====== prediction_loss_coefficient: list[float] | None = Field( default=None, desc="Loss coefficient for each prediction head.", @@ -297,11 +289,9 @@ def layer_class(self) -> "type[MultiTokenPrediction]": return MultiTokenPrediction - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - # TODO ====== Wrong ====== - return self.block.get_loss_definitions(count=count * self.prediction_heads) + self.head.get_loss_definitions( - count=count * self.prediction_heads - ) + @property + def max_prediction_distance(self) -> int: + return self.prediction_heads @config_class() @@ -311,8 +301,8 @@ class LanguageModelConfig(ModuleConfig): desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) - embeddings_layer: LanguageModelEmbeddingsConfig = Field() - output_layer: LanguageModelHeadBaseConfig = Field() + embeddings: LanguageModelEmbeddingsConfig = Field() + head: LanguageModelHeadBaseConfig = Field() # TODO: Allow overriding in sub-models? peft: PeftConfig = Field( desc="Configuration for parameter-efficient fine tuning.", @@ -331,17 +321,3 @@ class LanguageModelConfig(ModuleConfig): " Setting this parameter overrides the default choice. 
Note that setting to `False` will either do nothing or raise an error.", hint=FieldHint.testing, ) - - # def __len__(self) -> int: - # return len(self.decoder) + 2 * self.output_layer.prediction_heads - - # def __getitem__(self, index: int) -> BlockConfig: - # if index <= 0: - # Assert.eq(index, 0) - # return self.embeddings_layer - # elif index <= len(self.decoder): - # return self.decoder[index - 1] - # else: - # # Start at the last decoder layer so all MTP heads are treated similarly. - # index - len(self.decoder) - # return self.embeddings_layer diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 42b7e3d6c..4b0e3d102 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,3 +1,4 @@ +import abc import functools import logging import typing @@ -6,7 +7,7 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.core.ops import split_op +from fast_llm.core.ops import gather_op, split_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim @@ -22,10 +23,10 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LanguageModelEmbeddingsConfig, + LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, ) -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -34,7 +35,13 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): +class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](Block[ConfigType]): + @abc.abstractmethod + def get_output_weights(self) -> list[torch.Tensor]: + pass + + +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) 
@@ -99,15 +106,12 @@ def __init__( self._vocab_dim = TensorDim( "vocab", embeddings_config.vocab_size, self._parallel_dim if self._vocab_parallel else None ) - # Only the first head defines the output weights - if self._prediction_distance == 0 and not self._config.tied_weight: - # untie embedding weights - self.output_weights = self._config.output_weight.get_parameter( - (self._vocab_dim, self._hidden_dim), - default_initialization=init_normal_(std=self._hidden_size**-0.5), - lr_scale=self._lr_scale, - peft=self._peft, - ) + self.output_weights = self._config.output_weight.get_parameter( + (self._vocab_dim, self._hidden_dim), + default_initialization=init_normal_(std=self._hidden_size**-0.5), + lr_scale=self._lr_scale, + peft=self._peft, + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -116,7 +120,7 @@ def forward( if self._is_last_head: return TensorMeta.from_dims( (scalar_dim,), - tensor_name="Loss", + tensor_name=f"{self.module_name} output", reductions=( (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), ), @@ -185,7 +189,7 @@ def _forward_backward( self._parallel_dim.size if self._sequence_parallel_logits else 1 ) - output_weights = self._get_output_weights(kwargs) + output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( ln_output.detach(), targets, output_weights, grad_output, kwargs, losses ) @@ -244,12 +248,8 @@ def _get_targets( targets = None return targets - def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._config.tied_weight: - return kwargs[WORD_EMBEDDINGS_WEIGHT] - if self._prediction_distance > 0: - return kwargs[OUTPUT_WEIGHTS] - return self.output_weights + def get_output_weights(self) -> list[torch.Tensor]: + return [self.output_weights] def _logits_cross_entropy_forward_backward_split( self, @@ -266,6 +266,14 @@ def _logits_cross_entropy_forward_backward_split( ) if targets is None: # TODO: Make a proper way of returning the model output. 
+ loss = loss.detach() + if kwargs.get("global_logits"): + if self._vocab_parallel: + loss = gather_op(loss, self._parallel_dim.group, 2) + elif self._sequence_parallel_logits: + loss = gather_op( + loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + ) kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss return None, None else: @@ -423,9 +431,9 @@ def _logits_cross_entropy_forward_backward( if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) if self._config.distillation_model is not None and distillation_loss is not None: - losses[self._distillation_language_model_loss_name].append(distillation_loss.detach()) + losses[self._distillation_loss_name].append(distillation_loss.detach()) if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_loss_name].append(lm_loss.detach()) + losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) return loss, output_parallel_linear_backward(grad, context) if self.training else None @@ -467,7 +475,9 @@ def _distillation_loss_name(self) -> str: def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] if self._config.logit_z_loss: - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + loss_defs.append( + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + ) if self._config.enable_dpo: loss_defs.append( LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py new file mode 100644 index 000000000..9a3bef195 --- /dev/null +++ b/fast_llm/layers/language_model/language_model.py @@ -0,0 +1,61 @@ +import logging +import typing + +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer, LayerBase +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding + +logger = logging.getLogger(__name__) + + +class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + + self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( + distributed_config, + hidden_dim=self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.decoder = self._config.decoder.get_layer( + distributed_config, + self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.head = self._config.head.get_layer( + distributed_config, + self._config.embeddings, + hidden_dim=self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + + def get_layers(self) -> list["Layer"]: + return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + # Needed because the base 
class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + self.embeddings.preprocess(batch, kwargs) + self.decoder.preprocess(batch, kwargs) + self.head.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + return ( + self.embeddings.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.head.get_loss_definitions(count) + ) diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index 79555d866..e0eb8175d 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -1,6 +1,10 @@ +import functools +import typing + import torch from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase @@ -62,16 +66,27 @@ def __init__( ] ) + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: # Wrap all blocks in a namespace using the unique module name of the first one. + # This needs to be in a property because `module_name` is set after `__init__`. namespace = self.blocks[0].module_name - # Note: Pytorch won't redundantly register modules because it doesn't look into lists. - self._blocks_with_namespace = [ - LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers() - ] + return [LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers()] def get_layers(self) -> list[Layer]: return [ module - for block, head in zip(self._blocks_with_namespace, self.heads, strict=True) + for block, head in zip(self._layers_with_namespace, self.heads, strict=True) for module in (block, head) ] + + def get_output_weights(self) -> list[torch.Tensor]: + return sum((head.get_output_weights() for head in self.heads), []) + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + self._layers_with_namespace[0].preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.blocks[0].get_loss_definitions(count=count * self._config.prediction_heads) + [ + loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count) + ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 6721daea2..1e57f3b8c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -10,7 +10,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, AutoGPTHuggingfaceCheckpointFormat, @@ -148,41 +148,42 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): def _validate(self) -> None: if self.batch.sequence_length is None: # TODO: Drop this. 
-            self.batch.sequence_length = self.model.base_model.embeddings_layer.num_position_embeddings
+            self.batch.sequence_length = self.model.base_model.embeddings.num_position_embeddings
         if self.model.base_model.use_megatron_initialization:
             set_megatron_distributed_seeds(self.model.distributed)
         super()._validate()
-        if self.model.base_model.embeddings_layer.position_embeddings.enabled:
-            Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length)
+        if self.model.base_model.embeddings.position_embeddings.enabled:
+            Assert.geq(self.model.base_model.embeddings.num_position_embeddings, self.batch.sequence_length)

-        distillation_model = self.model.base_model.output_layer.distillation_model
-        dpo_reference_model = self.model.base_model.output_layer.dpo_reference_model
-
-        if self.model.base_model.output_layer.enable_dpo:
-            assert dpo_reference_model is not None
-            Assert.none(distillation_model)
+        # TODO: Avoid digging inside the model.
+        head = self.model.base_model.head
+        if isinstance(head, MultiTokenPredictionConfig):
+            prediction_heads = head.prediction_heads
+            head = head.head
         else:
-            Assert.none(dpo_reference_model)
+            prediction_heads = 1

-        if distillation_model is None and dpo_reference_model is None:
-            Assert.empty(self.reference_models)
-        else:
-            assert distillation_model is None or dpo_reference_model is None  # currently don't support both
-            expected_names = {name for name in (distillation_model, dpo_reference_model) if name is not None}
-            Assert.eq(self.reference_models.keys(), expected_names)
+        expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None}
+        Assert.eq(self.reference_models.keys(), expected_names)

         for reference_model in self.reference_models.values():
-            output_layer = reference_model.model.base_model.output_layer
-            Assert.none(output_layer.distillation_model)
-            Assert.none(output_layer.dpo_reference_model)
+            reference_head = reference_model.model.base_model.head
+            if isinstance(reference_head, MultiTokenPredictionConfig):
+                reference_prediction_heads = reference_head.prediction_heads
+                reference_head = reference_head.head
+            else:
+                reference_prediction_heads = 1
+            Assert.geq(reference_prediction_heads, prediction_heads)
+
+            Assert.none(reference_head.distillation_model)
+            Assert.none(reference_head.dpo_reference_model)
             # TODO: Support more LM head features.
- Assert.none(output_layer.cross_entropy_splits) + Assert.none(reference_head.cross_entropy_splits) Assert.eq( - reference_model.model.base_model.embeddings_layer.vocab_parallel, - self.model.base_model.embeddings_layer.vocab_parallel, + reference_model.model.base_model.embeddings.vocab_parallel, + self.model.base_model.embeddings.vocab_parallel, ) - Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 5b32c481d..4b9849630 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -24,11 +24,11 @@ class AprielDiscreteMamba2Converter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { "type": "discrete_mamba_2", "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), "add_linear_biases": config["ssm_cfg"]["bias"], "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, "n_qk_heads": config["ssm_cfg"]["n_qk_heads"], @@ -117,17 +117,17 @@ def get_converters( class AprielMamba2Converter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { "type": "mamba_2", "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), "add_linear_biases": config["ssm_cfg"]["bias"], "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, - "d_xb": config["ssm_cfg"].get("d_xb") or hidden_size, + "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, "dt_rank": ( - math.ceil(hidden_size) + math.ceil(config["hidden_size"]) if config["ssm_cfg"].get("dt_rank", "auto") == "auto" else config["ssm_cfg"]["dt_rank"] ), @@ -246,8 +246,8 @@ class AprielBlockConverter: _config_classes = {value: key for key, value in layout_names.items()} @classmethod - def import_config(cls, config: dict, hidden_size: int, layout_name: str = "t") -> dict: - return cls._converter_classes[cls._config_classes[layout_name]].import_config(config, hidden_size) + def import_config(cls, config: dict, layout_name: str = "t") -> dict: + return cls._converter_classes[cls._config_classes[layout_name]].import_config(config) @classmethod def export_config(cls, config) -> dict: @@ -270,18 +270,18 @@ class AprielDecoderConverter(MistralDecoderConverter): block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: layout = config["hybrid_block_layout"] if len(layout) == 1: return { - "block": cls.block_converter_class.import_config(config, hidden_size, layout[0]), + "block": cls.block_converter_class.import_config(config, layout[0]), "num_blocks": config["num_hidden_layers"], } else: return { "type": "pattern", "blocks": { - layout_name: cls.block_converter_class.import_config(config, 
hidden_size, layout_name) + layout_name: cls.block_converter_class.import_config(config, layout_name) for layout_name in set(layout) }, "pattern": layout, @@ -317,14 +317,13 @@ def get_converters( fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, - fast_llm_layer_start: int = 1, ) -> list[WeightConverter]: converters = [] for block_index in range(config.num_blocks): block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] converters += cls.block_converter_class.get_converters( block_config, - f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, ) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 629a3ceed..786d923f2 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -184,7 +184,7 @@ def import_weight( class LlamaAttentionConverter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: try: rope_type = config["rope_scaling"]["rope_type"] except (KeyError, TypeError): @@ -224,7 +224,7 @@ def import_config(cls, config: dict, hidden_size: int) -> dict: "dropout": config["attention_dropout"], } if out["head_size"] is None: - out["head_size"] = div(hidden_size, out["heads"]) + out["head_size"] = div(config["hidden_size"], out["heads"]) return out @@ -360,9 +360,9 @@ class LlamaBlockConverter: hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(config, hidden_size), + "mixer": cls.mixer_converter_class.import_config(config), "mlp": cls.mlp_converter_class.import_config(config), "normalization": cls.normalization_converter_class.import_config(config), } @@ -412,9 +412,9 @@ class LlamaDecoderConverter: block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config, hidden_size), + "block": cls.block_converter_class.import_config(config), "num_blocks": config["num_hidden_layers"], } @@ -434,13 +434,12 @@ def get_converters( fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, - fast_llm_layer_start: int = 1, ) -> list[WeightConverter]: converters = [] for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( config.block, - f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, ) @@ -477,47 +476,32 @@ class LlamaHeadConverter: @classmethod def import_config(cls, config: dict) -> dict: - return { - "tied_weight": config["tie_word_embeddings"], - "normalization": cls.normalization_converter_class.import_config(config), - } + return {"normalization": cls.normalization_converter_class.import_config(config)} @classmethod def export_config(cls, config: LanguageModelHeadConfig) -> dict: Assert.custom(isinstance, config, LanguageModelHeadConfig) - return safe_merge_dicts( - cls.normalization_converter_class.export_config(config.normalization), - {"tie_word_embeddings": config.tied_weight}, - ) + return 
cls.normalization_converter_class.export_config(config.normalization) @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, ) -> list[WeightConverter]: - converters = [] - for prediction_distance in range(config.prediction_heads): - if prediction_distance > 0: - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - "", - drop_on_export=True, - ) - converters += cls.normalization_converter_class.get_converters( + return [ + *cls.normalization_converter_class.get_converters( config.normalization, - f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"{fast_llm_prefix}.final_norm", f"model.norm", - drop_on_export=prediction_distance > 0, - ) - converters.append( + ), get_parameter_converter( - f"{fast_llm_prefix}.{start_index}.output_weights", + f"{fast_llm_prefix}.output_weights", "lm_head.weight", - drop_on_import=config.tied_weight, - ) - ) - - return converters + drop_on_import=exported_config["tie_word_embeddings"], + ), + ] class LlamaBaseModelConverter: @@ -529,41 +513,30 @@ class LlamaBaseModelConverter: @classmethod def import_config(cls, config: dict) -> dict: return { - "embeddings_layer": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config, config["hidden_size"]), - "output_layer": cls.head_converter_class.import_config(config), + "embeddings": cls.embeddings_converter_class.import_config(config), + "decoder": cls.decoder_converter_class.import_config(config), + "head": cls.head_converter_class.import_config(config), + "tied_embedding_weight": config["tie_word_embeddings"], } @classmethod def export_config(cls, config: GPTBaseModelConfig) -> dict: Assert.custom(isinstance, config, GPTBaseModelConfig) return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings_layer), + cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.output_layer), + cls.head_converter_class.export_config(config.head), + {"tie_word_embeddings": config.tied_embedding_weight}, ) @classmethod - def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings_layer, "layers.0", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "layers", "model.layers"), - *cls.head_converter_class.get_converters( - config.output_layer, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 - ), + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config.head, exported_config, "head"), ] - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - base_model_config = self._model.config.base_model - self.embeddings_converter_class.get_converters(base_model_config.embeddings_layer, "layers.0", "model") - converters = self.decoder_converter_class.get_converters(base_model_config.decoder, "layers", "model.layers") - 
self.head_converter_class.get_converters( - base_model_config.decoder, base_model_config.decoder.block, "layers", len(base_model_config.decoder) + 1 - ) - return converters - class LlamaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: GPTModel diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 4673f5b2c..bfc7d5569 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -16,8 +16,8 @@ class MistralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: - return safe_merge_dicts(super().import_config(config, hidden_size), {"window_size": config["sliding_window"]}) + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]}) @classmethod def export_config(cls, config: AttentionConfig) -> dict: diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 194c263f9..5b83fed69 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -4,63 +4,93 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import WeightConverter -from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, get_parameter_converter, ) -from fast_llm.utils import safe_merge_dicts +from fast_llm.utils import Assert, safe_merge_dicts class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod def import_config(cls, config: dict) -> dict: - return safe_merge_dicts( - super().import_config(config), - {"prediction_heads": config["prediction_heads"]}, - ) + return { + "type": "multi_token_prediction", + "block": LlamaBlockConverter.import_config(config), + "head": super().import_config(config), + "prediction_heads": config["prediction_heads"], + } @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: + def export_config(cls, config: MultiTokenPredictionConfig) -> dict: + Assert.custom(isinstance, config, MultiTokenPredictionConfig) return safe_merge_dicts( - super().export_config(config), + super().export_config(config.head), {"prediction_heads": config.prediction_heads}, ) @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, ) -> list[WeightConverter]: converters = [] for prediction_distance in range(config.prediction_heads): - if prediction_distance > 0: - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - f"model.mtp_heads.{prediction_distance - 1}", - ) + converters += cls.block_converter_class.get_converters( + config.block, + 
f"{fast_llm_prefix}.blocks.{prediction_distance}", + ( + f"model.layers.{exported_config["num_hidden_layers"]-1}" + if prediction_distance == 0 + else f"model.mtp_heads.{prediction_distance - 1}" + ), + ) converters += cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + config.head.normalization, + f"{fast_llm_prefix}.heads.{prediction_distance}.final_norm", f"model.mtp_norms.{prediction_distance}", ) converters.append( get_parameter_converter( - f"{fast_llm_prefix}.{start_index}.output_weights", + f"{fast_llm_prefix}.heads.0.output_weights", "lm_head.weight", - drop_on_import=config.tied_weight, + drop_on_import=exported_config["tie_word_embeddings"], ) ) return converters +class MTPLlamaDecoderConverter(LlamaDecoderConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "block": cls.block_converter_class.import_config(config), + "num_blocks": config["num_hidden_layers"] - 1, + } + + @classmethod + def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + # TODO: Support PatternBlockSequenceConfig with compatible configs. + Assert.custom(isinstance, config, FixedBlockSequenceConfig) + return safe_merge_dicts( + cls.block_converter_class.export_config(config.block), + {"num_hidden_layers": config.num_blocks + 1}, + ) + + class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MTPLlamaDecoderConverter]] = MTPLlamaDecoderConverter head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 2f99ae4c3..9215e6dc7 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -32,7 +32,7 @@ class HuggingfaceGPTModelForCausalLM(HuggingfaceBaseModelForCausalLM): # _supports_cache_class = False # _tied_weights_keys = [] - def forward( + def inner_forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, @@ -79,7 +79,7 @@ def forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self.fast_llm_base_model.preprocess( + batch = self.fast_llm_base_model.preprocess_batch( GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration ) ((input_, kwargs),) = batch @@ -99,10 +99,15 @@ def forward( else: kwargs["output_hidden_states"] = False + kwargs["global_logits"] = True + self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. 
- logits = kwargs["logits"] + if kwargs[AttentionKwargs.sequence_first]: + logits = kwargs["logits"].transpose(0, 1) + else: + logits = kwargs["logits"] # TODO: convert hidden state form dict to list to be the same as with HFs hidden_states = None diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9ddcf6300..2d6352686 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -4,8 +4,7 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import BaseModel, Layer -from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -13,6 +12,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): +class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], BaseModel[ConfigType]): """ A transformer-based language model generalizing the GPT model architecture. """ @@ -33,35 +33,12 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): - self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) - - hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) - self.embedding = self._config.embeddings_layer.get_layer( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - self.decoder = self._config.decoder.get_layer( - distributed_config, - hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - self.head = self._config.output_layer.get_layer( - distributed_config, - self._config.embeddings_layer, - hidden_dim=hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron( - param, self._config.decoder.block, config.embeddings_layer.hidden_size + param, self._config.decoder.block, config.embeddings.hidden_size ) # Noqa # TODO ====== Vision ====== @@ -69,9 +46,6 @@ def __init__( # self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) # self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) - def get_layers(self) -> list["Layer"]: - return self.embedding.get_layers() + self.decoder.get_layers() + self.head.get_layers() - # TODO ====== Vision ====== # def get_vision_layers(self) -> list[Layer]: # vit_layers = [ @@ -88,7 +62,7 @@ def get_layers(self) -> list["Layer"]: def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: - # TODO ====== Remove (Move batch splitting elsewhere) ====== + # TODO Remove (Move batch splitting 
elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence if isinstance(batch_meta, GPTBatchConfig): @@ -99,7 +73,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= self._config.output_layer.prediction_heads + sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length truncate_documents = True @@ -229,7 +203,7 @@ def preprocess_meta( return preprocessed_meta - def preprocess( + def preprocess_batch( self, batch: GPTBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, @@ -238,7 +212,7 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== + # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup if preprocessed_meta is None: @@ -247,7 +221,7 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size sequence_first = common_kwargs[AttentionKwargs.sequence_first] - prediction_heads: int = self._config.output_layer.prediction_heads + max_prediction_distance = self._config.head.max_prediction_distance batch.token_ids = batch.token_ids.to( device=self._distributed.device, @@ -261,7 +235,7 @@ def preprocess( (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] - reference_batch = reference_model.fast_llm_model.base_model.preprocess( + reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration ) @@ -304,10 +278,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] else: # TODO: Avoid multiple contiguous calls? 
- labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: @@ -317,20 +291,26 @@ def preprocess( if not spans.numel(): continue valid_spans = spans[ - (spans[:, 0] <= sequence_k + prediction_heads - 1) & (spans[:, 1] >= sequence_offset) + (spans[:, 0] <= sequence_k + max_prediction_distance - 1) + & (spans[:, 1] >= sequence_offset) ] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) + valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) valid_spans -= sequence_offset + loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - labels[start : end + 1, idx] = -100 + loss_mask[start : end + 1, idx] = False else: - labels[idx, start : end + 1] = -100 + loss_mask[idx, start : end + 1] = False + if self._config.output_layer.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) + kwargs[LanguageModelKwargs.labels] = labels + kwargs.update(reference_logits[i]) - # TODO ====== Preference spans ====== if batch.chosen_spans is not None: chosen_valid_spans = [] for spans in batch.chosen_spans: @@ -375,12 +355,6 @@ def preprocess( # labels_cloned = True # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) # Loss-masking for distillation losses - if self._config.distillation_model is not None: - loss_mask = torch.ones_like(labels, dtype=torch.bool) - loss_mask = torch.where(labels == -100, False, loss_mask) - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) # TODO ====== Vision ====== # if self._config.vision_encoder.enabled: @@ -400,69 +374,23 @@ def preprocess( # else [[]] * kwargs[AttentionKwargs.micro_batch_size] # ) # kwargs[LanguageModelKwargs.tokens] = tokens - - # TODO ====== Turn into super() call ====== - self.embedding.preprocess(tokens, kwargs) - self.decoder.preprocess(tokens, kwargs) - self.head.preprocess(tokens, kwargs) - - # TODO ====== Vision ====== # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) # if image_patches is not None: # preprocessed.append((image_patches, kwargs)) # else: # preprocessed.append((tokens, kwargs)) + self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) return preprocessed - # TODO ====== Vision ====== - # @property - # def embedding(self) -> LanguageModelEmbedding: - # return self.layers[self.embedding_layer_index] - - # @property - # def transformer_layers(self) -> list[TransformerBlock]: - # return self.layers[self.embedding_layer_index + 1 : -1] - - # @property - # def embedding_layer_index(self) -> int: - # if self._config.vision_encoder.enabled: - # return self._config.vision_encoder.transformer.num_layers + 2 - # else: - # return 0 - - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - # TODO ====== Tied weights ====== + def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: + # TODO: Integrate to the `LayerBase` 
interface, move to `LanguageModel`, `MultiTokenPrediction`? + output_weights = self.head.get_output_weights() if self._config.tied_embedding_weight: - raise NotImplementedError() - return {} - # if self._config.output_layer.tied_weight: - # return { - # WORD_EMBEDDINGS_WEIGHT: ( - # self.embedding.word_embeddings_weight, - # # TODO ====== Vision ====== - # # (self.embedding_layer_index, *self.model_head_indices), - # (0, *self.model_head_indices), - # ) - # } - # elif self._config.output_layer.prediction_heads > 1: - # return { - # OUTPUT_WEIGHTS: ( - # self.model_head.output_weights, - # tuple(self.model_head_indices), - # ) - # } - # else: - # return {} - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self.embeddings_layer.get_loss_definitions(count) - + self.decoder.get_loss_definitions(count) - + self.output_layer.get_loss_definitions(count) - ) + output_weights.insert(0, self.embeddings.word_embeddings_weight) + return {output_weights[0].tensor_name: output_weights} if len(output_weights) > 1 else {} class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 171a9a9b2..b4994b43e 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,13 +22,14 @@ def _get_sampling_parameters( parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - "vocab_size": self._config.model.base_model.embeddings_layer.vocab_size, + "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_spans": self._config.model.base_model.output_layer.enable_dpo, + # OK since DPO is not supported for MTP. 
+ "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, - "extra_tokens": self._config.model.base_model.output_layer.prediction_heads, + "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 5c0a2216c..40c4cfa87 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -1216,6 +1216,41 @@ def __init__(self, config: AprielHybridSSMConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and past_key_values is None: + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + return output + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 7bc215670..0de823e2a 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,9 +7,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -21,9 +20,7 @@ def _reverse_kl_loss( loss_mask: torch.Tensor | None, teacher_softmax_temperature: float = 1.0, ): - scaled_target = target / teacher_softmax_temperature - - scaled_target = torch.clamp(target, min=-50, max=50) + scaled_target = torch.clamp(target / teacher_softmax_temperature, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): @@ -101,56 +98,61 @@ def _lm_head( @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( - ("config_dict", "distributed_config_dict", "loss_masking"), + ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), ( - ({}, {}, False), - ({}, {"compute_dtype": DataType.bfloat16}, False), - ({"embeddings_layer": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False), - ({"sequence_first": True}, {}, False), - ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False), - ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False), - ({"output_layer": {"tied_weight": False}}, {}, False), - ({"output_layer": {"prediction_heads": 2}}, {}, False), - ({}, {}, True), + ({}, {}, False, 1), + ({}, {"compute_dtype": DataType.bfloat16}, False, 1), + ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), + ({"sequence_first": True}, {}, False, 1), + ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), + ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), + ({"tied_embedding_weight": True}, {}, False, 1), + ({}, {}, False, 2), + ({}, {}, True, 1), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.cross_entropy, } }, {}, False, + 1, ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.reverse_kl, } }, {}, False, + 1, ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + "language_model_loss_factor": 1.0, } }, {}, True, + 1, ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.reverse_kl, } }, {}, True, + 1, ), ), ) @@ -159,24 +161,37 @@ def test_lm_head( config_dict: dict[str, typing.Any], distributed_config_dict: dict[str, typing.Any], loss_masking: bool, + prediction_heads: int, ): + torch.cuda.manual_seed(0) + torch.manual_seed(0) + head_config = { + "cross_entropy_implementation": 
cross_entropy_impl, + "normalization": {"type": "rms_norm"}, + } config = GPTBaseModelConfig.from_dict( { "decoder": { "num_blocks": 0, }, - "embeddings_layer": { + "embeddings": { "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE, }, - "output_layer": { - "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm"}, - }, + "head": ( + head_config + if prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": prediction_heads, + } + ), }, config_dict, update_type=UpdateType.update, ) + head_config: LanguageModelHeadConfig = config.head if prediction_heads == 1 else config.head.head model, distributed = get_base_model( GPTModelConfig.from_dict( @@ -188,22 +203,22 @@ def test_lm_head( ) sequence_first = config.sequence_first or ( - config.output_layer.cross_entropy_splits is not None and config.output_layer.cross_entropy_splits > 1 + head_config.cross_entropy_splits is not None and head_config.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( distributed.config.optimization_dtype.torch - if config.embeddings_layer.full_precision_residual + if config.embeddings.full_precision_residual else distributed.config.compute_dtype.torch ), device=distributed.device, requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.head.max_prediction_distance - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.max_prediction_distance - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -213,7 +228,7 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if config.output_layer.distillation_model is None: + if head_config.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -226,42 +241,43 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.output_layer.prediction_heads == 1 + assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.output_layer.distillation_model}_logits"] = target + kwargs[f"{head_config.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.output_layer.tied_weight or config.output_layer.prediction_heads > 1: - logit_weight = ( + if config.tied_embedding_weight or config.head.max_prediction_distance > 1: + logit_weight = torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ) - .normal_(config.embeddings_layer.hidden_size**-0.5) - .requires_grad_(True) + ).normal_(config.embeddings.hidden_size**-0.5) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None - # TODO ====== Fix ====== - for prediction_distance, layer_index in enumerate(model.model_head_indices): + for prediction_distance, head in enumerate((model.head,) if prediction_heads == 1 else model.head.heads): # Prepare the LM head - head: LanguageModelHead = model[layer_index] Assert.custom(isinstance, head, LanguageModelHead) 
Assert.eq(head._prediction_distance, prediction_distance) - stage = get_stage([head], distributed) + is_duplicate = config.tied_embedding_weight or prediction_distance > 0 + stage = get_stage( + [head], + distributed, + tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], + tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + ) # Get reference outputs and grads - if logit_weight is None: - logit_weight = head.output_weights - else: + if is_duplicate: logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) logit_weight.param_grad_is_zero = True + else: + logit_weight = head.output_weights ref_input = input_.detach().requires_grad_() ref_rms_weight = head.final_norm.weight.detach().requires_grad_() @@ -277,9 +293,9 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.output_layer.logits_scale_factor, - logit_z_loss=config.output_layer.logit_z_loss, - distillation_loss_implementation=config.output_layer.distillation_loss_implementation, + logit_scale_factor=head_config.logits_scale_factor, + logit_z_loss=head_config.logit_z_loss, + distillation_loss_implementation=head_config.distillation_loss_implementation, ) # Prepare LM head inputs @@ -292,13 +308,18 @@ def test_lm_head( output_grad = torch.randn_like(shared_hidden) loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - Assert.eq(head._loss_name, loss_name) loss_keys = {loss_name} if ref_z_loss is not None: - loss_keys.add("z_loss") - if config.output_layer.distillation_model is not None: + loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + if head_config.distillation_model is not None: loss_keys.add("distillation_loss") - loss_keys.add("distil_lm_loss") + if head_config.language_model_loss_factor > 0: + loss_keys.add("distillation_language_model_loss") + + Assert.eq( + {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, + {loss_key: 1 for loss_key in loss_keys}, + ) losses = {key: [] for key in loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -306,7 +327,7 @@ def test_lm_head( threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 - ) * config.output_layer.logits_scale_factor + ) * head_config.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 97a618cfc..5db18d7ff 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -328,7 +328,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) test_input = torch.randint( 0, - model_ref.config.fast_llm_config.base_model.embeddings_layer.vocab_size, + model_ref.config.fast_llm_config.base_model.embeddings.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda", diff --git a/tests/test_attention.py b/tests/test_attention.py index dceaa8282..a19cba8f0 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,13 +2,13 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.attention.attention import Attention from 
fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert -def test_varlen_preprocessor(): +def test_varlen_preprocessing(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: # [0...7,0...3] + [0...10,0] -> [0,8,12,23,24] @@ -28,8 +28,12 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - varlen_preprocessor = FlashAttnVarlenPreprocessor( - AttentionConfig(head_size=64), DistributedConfig(compute_dtype="bfloat16") + attention = Attention( + AttentionConfig(head_size=64), + DistributedConfig(compute_dtype="bfloat16"), + hidden_dim=TensorDim("", 1), + lr_scale=None, + peft=None, ) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { @@ -40,6 +44,6 @@ def test_varlen_preprocessor(): AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) + attention.preprocess(torch.empty(1, device="cpu"), kwargs) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_config.py b/tests/test_config.py index 6d2583ba3..326200537 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,7 +74,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "embeddings_layer": { + "embeddings": { "hidden_size": 1024, # Default }, "decoder": { @@ -92,7 +92,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, # Default }, - "output_layer": {"tied_weight": False}, + "tied_embedding_weight": False, }, "multi_stage": {"zero_stage": 3}, "distributed": {"compute_dtype": "bfloat16"}, @@ -105,7 +105,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "embeddings_layer": {"hidden_size": 512, "vocab_size": 1000}, + "embeddings": {"hidden_size": 512, "vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -127,51 +127,50 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } ) serialized_config = pretrained_config.model.to_dict() - expected_config = {"type": "gpt", "distributed": DistributedConfig().to_dict()} + expected_config = {"distributed": DistributedConfig().to_dict()} if load_config == ModelConfigType.fast_llm: expected_config["multi_stage"] = {"zero_stage": 3} expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { - "embeddings_layer": { + "embeddings": { "hidden_size": 512, "vocab_size": 1000, }, "decoder": { - "type": "fixed", "block": { - "type": "decoder", "mixer": { - "type": "attention", - "rotary": {"type": "default"}, "window_size": 32, "head_groups": 1, }, "mlp": { - "type": "mlp", "intermediate_size": 4096, # Implicit default, default value "activation": "silu", # Implicit default, non-default value }, - "normalization": {"type": "rms_norm", "implementation": "triton"}, + "normalization": {"implementation": "triton"}, 
}, "num_blocks": 12, }, - "output_layer": {"tied_weight": False, "normalization": {"type": "layer_norm"}}, - "peft": {"type": "lora", "freeze_others": False}, + "tied_embedding_weight": False, + "peft": {"freeze_others": False}, } else: - base_model_update["decoder"]["type"] = "fixed" - base_model_update["decoder"]["block"]["type"] = "decoder" - base_model_update["decoder"]["block"]["normalization"]["type"] = "layer_norm" - base_model_update["decoder"]["block"]["mixer"]["type"] = "attention" - base_model_update["decoder"]["block"]["mixer"]["rotary"] = {"type": "none"} - base_model_update["decoder"]["block"]["mlp"] = {"type": "mlp"} - base_model_update["output_layer"] = {"normalization": {"type": "layer_norm"}} - base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update - check_equal_nested(serialized_config, expected_config) + check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) + + +def _trim_type(config: dict): + # Serialization inserts dynamic types, we ignore them during the comparison. + if "type" in config: + del config["type"] + for key in list(config): + if isinstance(value := config[key], dict): + _trim_type(value) + if not value: + del config[key] + return config def _check_dim(dim: DistributedDim, name: str, rank: int, size: int, global_rank: int): diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index cc5a60a8a..407b47767 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -5,7 +5,6 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup @@ -42,14 +41,14 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, DecoderBlock) else 0 - for layer in model_ref.base_model.layers + sum(p.numel() for p in layer.unwrap().mlp.parameters()) if layer.module_name.startswith("decoder") else 0 + for layer in model_ref.base_model.get_layers() ] # Make sure each layer has its own buffer so the check below works. Assert.eq( - num_stages := len(model_ref.base_model.layers), - len(model_frozen.base_model.layers), + num_stages := len(model_ref.base_model.get_layers()), + len(model_frozen.base_model.get_layers()), len(model_ref.stages), len(model_frozen.stages), ) diff --git a/tests/utils/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py index 51ee66d31..1c8ebd76a 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/tests/utils/compare_tensor_logs.py @@ -79,7 +79,9 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): keys_test = set(dict_test) if keys_ref != keys_test: errors.append( - f">>>> {name} do not match. Missing = {keys_ref - keys_test}, extra = {keys_test - keys_ref}." + f">>>> {name} do not match." + f"\n Missing = \n{"\n * ".join(keys_ref - keys_test)}" + f"\n Extra = \n{"\n * ".join(keys_test - keys_ref)}" ) # Avoid set to preserve ordering. 
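For illustration, the `_trim_type` helper added to tests/test_config.py above can be exercised on its own: dynamic "type" keys are dropped recursively, and dicts left empty afterwards are removed too. The nested config dict below is made up for the example and is not taken from the test suite.

# Standalone copy of the `_trim_type` helper from the hunk above, with a made-up input.
def _trim_type(config: dict) -> dict:
    if "type" in config:
        del config["type"]
    for key in list(config):
        if isinstance(value := config[key], dict):
            _trim_type(value)
            if not value:
                del config[key]
    return config


serialized = {
    "decoder": {"type": "fixed", "block": {"type": "decoder", "mixer": {"type": "attention", "head_groups": 1}}},
    "peft": {"type": "lora"},
}
# "type" keys vanish at every level; "peft" becomes empty and is dropped entirely.
assert _trim_type(serialized) == {"decoder": {"block": {"mixer": {"head_groups": 1}}}}
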
diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 863be2cae..fac595905 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -110,7 +110,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="ce4", compare="simple", - config_args=["model.base_model.output_layer.cross_entropy_splits=4"], + config_args=["model.base_model.head.cross_entropy_splits=4"], num_gpus=1, compare_config=_compare_layer_mismatch, ), @@ -228,8 +228,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.embeddings_layer.vocab_parallel=False", - "model.base_model.output_layer.cross_entropy_splits=4", + "model.base_model.embeddings.vocab_parallel=False", + "model.base_model.head.cross_entropy_splits=4", ], num_gpus=2, compare_config=_compare_layer_match, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index aa8100126..6b313aa8a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -189,7 +189,7 @@ def _update_and_add_testing_config( }, "model": { "base_model": { - "embeddings_layer": { + "embeddings": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, "hidden_size": 256, @@ -215,7 +215,8 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "output_layer": {"output_weight": init_1}, + "head": {"output_weight": init_1}, + "tied_embedding_weight": True, }, "multi_stage": { "debug_param_init": _LOG_LEVEL, @@ -324,7 +325,7 @@ def _update_and_add_testing_config( updates={ ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 4, ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "default", - ("model", "base_model", "embeddings_layer", "position_embeddings", "enabled"): False, + ("model", "base_model", "embeddings", "position_embeddings", "enabled"): False, }, megatron_args=[ "--group-query-attention", @@ -354,8 +355,8 @@ def _update_and_add_testing_config( ("model", "base_model", "decoder", "block", "mlp", "activation"): "silu", ("model", "base_model", "decoder", "block", "mlp", "add_linear_biases"): False, ("model", "base_model", "decoder", "block", "normalization", "type"): "rms_norm", - ("model", "base_model", "output_layer", "normalization", "type"): "rms_norm", - ("model", "base_model", "output_layer", "tied_weight"): False, + ("model", "base_model", "head", "normalization", "type"): "rms_norm", + ("model", "base_model", "tied_embedding_weight"): False, }, megatron_args=[ "--swiglu", @@ -436,12 +437,22 @@ def _update_and_add_testing_config( }, ) + +_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] + + _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", "mtp_llama", updates={ - ("model", "base_model", "output_layer", "prediction_heads"): 2, + ("model", "base_model", "head"): { + "type": "multi_token_prediction", + "block": _llama_block, + "head": MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["head"], + "prediction_heads": 2, + }, + ("model", "base_model", "decoder", "num_blocks"): 1, }, # Megatron doesn't support multi-token prediction. 
megatron_args=None, @@ -456,6 +467,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, compare_factor=2.0, + # Arg update for cross-entropy splits doesn't work here. + skip_tests=("ce4", "ms"), ) _update_and_add_testing_config( @@ -549,8 +562,6 @@ def _update_and_add_testing_config( compare_factor=2.0, ) -_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] - _update_and_add_testing_config( # Tests hybrid Mamba, llamba converter. diff --git a/tests/utils/utils.py b/tests/utils/utils.py index b086c291f..098f0240e 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -11,7 +11,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier -from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig @@ -32,22 +32,26 @@ def result_path(): def get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. - base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) + base_model = config.get_base_model_config_class().get_base_model(config.base_model, config.distributed) base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed -def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): +def get_stage( + layers: list[Layer], + distributed: Distributed, + tied_parameter_duplicates: typing.Iterable[str] = (), + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, +): # Create a fast-llm stage which allocates and initializes meta tensors correctly. 
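    # The tied-parameter arguments let tests exercise weight tying: duplicate parameter names are
    # declared on the stage and their shared buffers are supplied at `setup` time.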
stage = Stage( config=StageConfig(), - layers=base_model, + layers=layers, distributed_config=distributed.config, - begin=0, - end=1, index=0, + tied_parameter_duplicates=tied_parameter_duplicates, ) - stage.setup(distributed=distributed) + stage.setup(distributed=distributed, tied_parameter_duplicate_buffers=tied_parameter_duplicate_buffers) stage.initialize_weights() stage.restore_parameters() stage.reset_gradients() From 3d8497256b280c9cedde58a195d8c3ebd87c6eb2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 3 Oct 2025 23:41:28 -0400 Subject: [PATCH 14/15] model --- fast_llm/layers/block/block.py | 1 + fast_llm/layers/language_model/config.py | 33 +++-- .../layers/language_model/language_model.py | 40 ++++-- fast_llm/layers/vision/config.py | 79 +++++------ fast_llm/layers/vision/vision_encoder.py | 67 +++++++++ fast_llm/models/gpt/config.py | 6 + fast_llm/models/gpt/conversion/llama.py | 16 +-- fast_llm/models/gpt/model.py | 98 +------------ fast_llm/models/gpt/trainer.py | 10 -- fast_llm/models/multimodal/__init__.py | 0 fast_llm/models/multimodal/config.py | 89 ++++++++++++ fast_llm/models/multimodal/model.py | 133 ++++++++++++++++++ fast_llm/models/multimodal/trainer.py | 14 ++ tests/layers/test_lm_head.py | 2 +- 14 files changed, 398 insertions(+), 190 deletions(-) create mode 100644 fast_llm/layers/vision/vision_encoder.py create mode 100644 fast_llm/models/multimodal/__init__.py create mode 100644 fast_llm/models/multimodal/config.py create mode 100644 fast_llm/models/multimodal/model.py create mode 100644 fast_llm/models/multimodal/trainer.py diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index ab6cb22b0..67ce5eea9 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -103,6 +103,7 @@ def __init__( config: ConfigType, distributed_config: DistributedConfig, *, + # TODO: Review. Use `input_dim(s)` and `output_dim(s)` instead? 
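        # `hidden_dim` is the residual-stream dimension the block reads from and writes back to.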
hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d2fbc4909..25fa2d91e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,7 +2,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -16,6 +15,7 @@ if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase + from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction @@ -41,12 +41,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) - hidden_size: int = Field( - default=1024, - desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", @@ -295,24 +289,29 @@ def max_prediction_distance(self) -> int: @config_class() -class LanguageModelConfig(ModuleConfig): - # TODO: block +class LanguageModelConfig(BlockConfig): decoder: BlockSequenceConfig = Field( desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) - embeddings: LanguageModelEmbeddingsConfig = Field() - head: LanguageModelHeadBaseConfig = Field() - # TODO: Allow overriding in sub-models? - peft: PeftConfig = Field( - desc="Configuration for parameter-efficient fine tuning.", + embeddings: LanguageModelEmbeddingsConfig = Field( hint=FieldHint.architecture, + desc="Configuration for the language model embeddings.", + ) + head: LanguageModelHeadBaseConfig = Field( + hint=FieldHint.architecture, desc="Configuration for the language model head(s)." ) tied_embedding_weight: bool = Field( default=False, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.architecture, ) + hidden_size: int = Field( + default=1024, + desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) sequence_first: bool | None = Field( default=None, desc="Override the default dimension ordering", @@ -321,3 +320,9 @@ class LanguageModelConfig(ModuleConfig): " Setting this parameter overrides the default choice. 
Note that setting to `False` will either do nothing or raise an error.", hint=FieldHint.testing, ) + + @property + def layer_class(self) -> "type[LanguageModel]": + from fast_llm.layers.language_model.language_model import LanguageModel + + return LanguageModel diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 9a3bef195..56d41dc3a 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -1,52 +1,64 @@ import logging import typing -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer, LayerBase +import torch + +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding logger = logging.getLogger(__name__) -class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase): +class LanguageModel[ConfigType: LanguageModelConfig](BlockBase[ConfigType]): _config: ConfigType def __init__( self, config: ConfigType, distributed_config: DistributedConfig, + *, + # Unused, but required by the `BlockBase` interface. + hidden_dim: TensorDim | None = None, + lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, distributed_config) - - self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + super().__init__( + config, + distributed_config, + hidden_dim=TensorDim("hidden", self._config.hidden_size), + lr_scale=lr_scale, + peft=peft, + ) self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( distributed_config, hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) self.decoder = self._config.decoder.get_layer( distributed_config, self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) self.head = self._config.head.get_layer( distributed_config, self._config.embeddings, hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) - def get_layers(self) -> list["Layer"]: + def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() - def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index fbd0ed807..1af986eef 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -1,15 +1,14 @@ -import enum import typing -from fast_llm.config import Config, Field, FieldHint, config_class -from fast_llm.engine.base_model.config import ModuleConfig +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.linear.config import Convolution2DConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MLPBaseConfig +from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + from fast_llm.layers.vision.vision_encoder import VisionEncoder @config_class() @@ -51,12 +50,6 @@ class ImageNormalizationConfig(Config): ) -class VisionEncoderType(str, enum.Enum): - none = "none" - # TODO: better name? normalization, patch size, adapter can change based on implementation, no standard way currently. - pixtral = "pixtral" - - @config_class() class PatchConvolutionConfig(BlockConfig): _abstract = False @@ -81,26 +74,34 @@ class PatchConvolutionConfig(BlockConfig): @config_class(registry=True) -class VisionEncoderConfig(ModuleConfig): +class VisionEncoderConfig(BlockConfig): _abstract = False - patch_convolution_layer: PatchConvolutionConfig = Field( + patch_convolution: PatchConvolutionConfig = Field( desc="Configuration for the patch convolution layer.", hint=FieldHint.architecture, ) - adapter_layer: MLPBaseConfig = Field( + adapter: MLPBaseConfig = Field( desc="Configuration for the adapter layer.", hint=FieldHint.architecture, ) + # TODO: ====== Appropriate name?? ====== decoder: BlockSequenceConfig = Field( desc="Configuration for the vision decoder.", hint=FieldHint.architecture, ) - - type: VisionEncoderType = Field( - default=VisionEncoderType.none, - desc="Type of the vision encoder. Choices: none, pixtral.", + hidden_size: int = Field( + default=1024, + desc="Size of the vision encoder main hidden dimension.", hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) + + @property + def layer_class(self) -> "type[VisionEncoder]": + from fast_llm.layers.vision.vision_encoder import VisionEncoder + + return VisionEncoder + # transformer: TransformerConfig = Field( # desc="Configuration for the vision transformer architecture.", # hint=FieldHint.core, @@ -134,20 +135,20 @@ class VisionEncoderConfig(ModuleConfig): # desc="Whether to use bias in the adapter linear layer.", # hint=FieldHint.optional, # ) - image_normalization: ImageNormalizationConfig = Field( - desc="Configuration for the normalization layers applied to the image patches.", - hint=FieldHint.optional, - ) - image_break_token: int | None = Field( - default=None, - desc="Token id to separate image rows. If None, no token id is applied.", - hint=FieldHint.optional, - ) - image_end_token: int | None = Field( - default=None, - desc="Token id to indicate the end of an image. If None, no token id is applied.", - hint=FieldHint.optional, - ) + # image_normalization: ImageNormalizationConfig = Field( + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional, + # ) + # image_break_token: int | None = Field( + # default=None, + # desc="Token id to separate image rows. 
If None, no token id is applied.", + # hint=FieldHint.optional, + # ) + # image_end_token: int | None = Field( + # default=None, + # desc="Token id to indicate the end of an image. If None, no token id is applied.", + # hint=FieldHint.optional, + # ) # adapter_lr_scale: float | None = Field( # default=None, # desc="Custom learning rate scale for the adapter weights.", @@ -166,19 +167,3 @@ class VisionEncoderConfig(ModuleConfig): # hint=FieldHint.optional, # valid=check_field(Assert.geq, 0), # ) - - def _validate(self) -> None: - with self._set_implicit_default(): - if self.adapter_init_method_std is None: - self.adapter_init_method_std = self.adapter_size**-0.5 - super()._validate() - - @property - def enabled(self) -> bool: - return self.type != VisionEncoderType.none - - -for name in VisionEncoderType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - VisionEncoderConfig.register_subclass(name.value, VisionEncoderConfig) diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py new file mode 100644 index 000000000..b4fa189d5 --- /dev/null +++ b/fast_llm/layers/vision/vision_encoder.py @@ -0,0 +1,67 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision.config import VisionEncoderConfig + +logger = logging.getLogger(__name__) + + +class VisionEncoder[ConfigType: VisionEncoderConfig](BlockBase[VisionEncoderConfig]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + vision_hidden_dim = TensorDim("hidden", self._config.hidden_size) + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self.patch_convolution = self._config.patch_convolution.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + # TODO: ====== Appropriate name?? ====== + self.decoder = self._config.decoder.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + # TODO: ====== Hidden dim ====== + self.adapter = self._config.adapter.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def get_layers(self) -> list["Layer"]: + return self.patch_convolution.get_layers() + self.decoder.get_layers() + self.adapter.get_layers() + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + self.patch_convolution.preprocess(batch, kwargs) + self.decoder.preprocess(batch, kwargs) + self.adapter.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
+ return ( + self.patch_convolution.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.adapter.get_loss_definitions(count) + ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 1e57f3b8c..a901a0466 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -10,6 +10,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, @@ -84,6 +85,11 @@ def micro_batch_splits(self) -> int: class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig): _abstract = False + # TODO: Allow overriding in sub-models? + peft: PeftConfig = Field( + desc="Configuration for parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) # Debug, to get an exact match with megatron init. use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 786d923f2..a92492260 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -449,19 +449,13 @@ def get_converters( class LlamaEmbeddingsConverter: @classmethod def import_config(cls, config: dict) -> dict: - return { - "vocab_size": config["vocab_size"], - "hidden_size": config["hidden_size"], - } + return {"vocab_size": config["vocab_size"]} @classmethod def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) assert not config.position_embeddings.enabled - return { - "vocab_size": config.vocab_size, - "hidden_size": config.hidden_size, - } + return {"vocab_size": config.vocab_size} @classmethod def get_converters( @@ -516,6 +510,7 @@ def import_config(cls, config: dict) -> dict: "embeddings": cls.embeddings_converter_class.import_config(config), "decoder": cls.decoder_converter_class.import_config(config), "head": cls.head_converter_class.import_config(config), + "hidden_size": config["hidden_size"], "tied_embedding_weight": config["tie_word_embeddings"], } @@ -526,7 +521,10 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), cls.head_converter_class.export_config(config.head), - {"tie_word_embeddings": config.tied_embedding_weight}, + { + "tie_word_embeddings": config.tied_embedding_weight, + "hidden_size": config.hidden_size, + }, ) @classmethod diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2d6352686..158bbd92c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -30,34 +30,14 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], Ba def __init__( self, - config: GPTBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): - super().__init__(config, distributed_config) + super().__init__(config, distributed_config, lr_scale=self._config.lr_scale, peft=self._config.peft) if self._config.use_megatron_initialization: for param in 
self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron( - param, self._config.decoder.block, config.embeddings.hidden_size - ) # Noqa - - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) - # self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) - - # TODO ====== Vision ====== - # def get_vision_layers(self) -> list[Layer]: - # vit_layers = [ - # VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) - # for idx in range(self._config.vision_encoder.transformer.num_layers) - # ] - # return [ - # PatchConv(self._config.vision_encoder, self._tensor_space), - # *vit_layers, - # VisionAdapter(self._config.vision_encoder, self._tensor_space), - # MultiModalEmbedding(self._config, self._tensor_space), - # ] + param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType @@ -123,33 +103,6 @@ def preprocess_meta( LanguageModelKwargs.mask_inputs: not truncate_documents, } - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # try: - # max_image_size = batch_meta.max_image_size - # except AttributeError: - # max_image_size = 256 - # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") - # vision_kwargs = { - # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, - # VisionEncoderKwargs.max_image_size: max_image_size, - # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, - # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, - # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, - # } - # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] - # vision_hidden_dims = ( - # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) - # if sequence_first - # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) - # ) - # vision_kwargs.update( - # { - # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, - # } - # ) - # common_kwargs.update(vision_kwargs) - sequence_k_pasts = range( sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, @@ -192,13 +145,6 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size - # preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) - # else: - # preprocessed_meta.append((tokens, kwargs)) - preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -342,44 +288,6 @@ def preprocess_batch( rejected_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # if self._config.vision_encoder.image_break_token is not None: - # if not labels_cloned: - # labels = labels.clone() - # labels_cloned = True - # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) - # if self._config.vision_encoder.image_end_token is not None: - # if not 
labels_cloned: - # labels = labels.clone() - # labels_cloned = True - # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) - # Loss-masking for distillation losses - - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # batch_images = ( - # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] - # ) - # kwargs[VisionEncoderKwargs.images] = [ - # [ - # img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - # for img in images - # ] - # for images in batch_images - # ] - # kwargs[VisionEncoderKwargs.image_positions] = ( - # batch.image_positions - # if batch.image_positions is not None - # else [[]] * kwargs[AttentionKwargs.micro_batch_size] - # ) - # kwargs[LanguageModelKwargs.tokens] = tokens - # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) - # if image_patches is not None: - # preprocessed.append((image_patches, kwargs)) - # else: - # preprocessed.append((tokens, kwargs)) - self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b4994b43e..2c71a5a16 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,14 +33,4 @@ def _get_sampling_parameters( } ) - # TODO ====== Vision ====== - # if self._config.model.base_model.vision_encoder.enabled: - # parameters.update( - # { - # "patch_size": self._config.model.base_model.vision_encoder.patch_size, - # "max_image_size": self._config.batch.max_image_size, - # "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, - # "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, - # } - # ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/fast_llm/models/multimodal/__init__.py b/fast_llm/models/multimodal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py new file mode 100644 index 000000000..2415734e4 --- /dev/null +++ b/fast_llm/models/multimodal/config.py @@ -0,0 +1,89 @@ +import logging +import typing + +from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.vision.config import VisionEncoderConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTModelConfig, + GPTTrainerConfig, + PretrainedGPTModelConfig, +) + +if typing.TYPE_CHECKING: + from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM + from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalModel, MultiModalModelInferenceRunner + from fast_llm.models.multimodal.trainer import MultiModalTrainer + +logger = logging.getLogger(__name__) + + +@config_class() +class MultiModalBatchConfig(GPTBatchConfig): + pass + + +@config_class() +class MultiModalBaseModelConfig(GPTBaseModelConfig): + vision_encoder: VisionEncoderConfig = Field( + hint=FieldHint.architecture, + desc="Configuration for the vision encoder.", + ) + + @property + def base_model_class(self) -> type["MultiModalBaseModel"]: + from fast_llm.models.multimodal.model import 
MultiModalBaseModel
+
+        return MultiModalBaseModel
+
+
+@config_class(dynamic_type={FastLLMModelConfig: "multimodal"})
+class MultiModalModelConfig(GPTModelConfig):
+    _abstract = False
+    model_name: typing.ClassVar[str] = "multimodal"
+    base_model: MultiModalBaseModelConfig = FieldUpdate()
+    # TODO: ====== Conversion ======
+    checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats
+
+    @classmethod
+    def get_model_class(cls) -> type["MultiModalModel"]:
+        from fast_llm.models.multimodal.model import MultiModalModel
+
+        return MultiModalModel
+
+    @classmethod
+    def get_inference_runner_class(cls) -> type["MultiModalModelInferenceRunner"]:
+        from fast_llm.models.multimodal.model import MultiModalModelInferenceRunner
+
+        return MultiModalModelInferenceRunner
+
+    @classmethod
+    def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceMultiModalModelForCausalLM"]:
+        from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM
+
+        return HuggingfaceMultiModalModelForCausalLM
+
+
+@config_class()
+class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig):
+    _abstract = False
+    model: MultiModalModelConfig = FieldUpdate()
+
+
+@config_class(dynamic_type={RunnableConfig: "train_multimodal", TrainerConfig: "multimodal"})
+class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig):
+    data: MultiModalDataConfig = FieldUpdate()
+    batch: MultiModalBatchConfig = FieldUpdate()
+    # TODO: Use dynamic model type?
+    reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate()
+
+    @classmethod
+    def get_trainer_class(cls) -> type["MultiModalTrainer"]:
+        from fast_llm.models.multimodal.trainer import MultiModalTrainer
+
+        return MultiModalTrainer
diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py
new file mode 100644
index 000000000..7426191f7
--- /dev/null
+++ b/fast_llm/models/multimodal/model.py
@@ -0,0 +1,133 @@
+import logging
+import typing
+
+import torch
+
+from fast_llm.data.data.gpt.data import GPTBatch
+from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
+from fast_llm.engine.inference.runner import InferenceRunner
+from fast_llm.models.gpt.config import GPTBatchConfig
+from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
+from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalBatchConfig, MultiModalModelConfig
+from fast_llm.tensor import TensorMeta
+
+logger = logging.getLogger(__name__)
+
+
+class MultiModalBaseModel[ConfigType: MultiModalBaseModelConfig](GPTBaseModel[ConfigType]):
+    """
+    A multi-modal model combining a vision encoder with the GPT language model architecture.
+ """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + self.vision_encoder = self._config.vision_encoder.get_layer( + distributed_config, + self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + + def preprocess_meta( + self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + ) -> list[tuple[TensorMeta, dict]]: + # TODO Remove (Move batch splitting elsewhere) + # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # try: + # max_image_size = batch_meta.max_image_size + # except AttributeError: + # max_image_size = 256 + # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + # vision_kwargs = { + # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + # VisionEncoderKwargs.max_image_size: max_image_size, + # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + # } + # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + # vision_hidden_dims = ( + # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + # if sequence_first + # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + # ) + # vision_kwargs.update( + # { + # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + # } + # ) + # common_kwargs.update(vision_kwargs) + + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + # preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + # else: + # preprocessed_meta.append((tokens, kwargs)) + pass + + def preprocess_batch( + self, + batch: GPTBatch, + preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + *, + phase: PhaseType, + iteration: int, + metrics: dict | None = None, + ) -> list[tuple[torch.Tensor, dict]]: + # TODO Move batch splitting elsewhere, align interface with LayerBase + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # if self._config.vision_encoder.image_break_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + # if self._config.vision_encoder.image_end_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + # Loss-masking for distillation losses + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # batch_images = ( + # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[VisionEncoderKwargs.images] = [ + # [ + # img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + # for img in images + # ] + # for images in batch_images + # ] + # kwargs[VisionEncoderKwargs.image_positions] = ( + # batch.image_positions + # if batch.image_positions is not None + # else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[LanguageModelKwargs.tokens] = tokens + # 
image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + # if image_patches is not None: + # preprocessed.append((image_patches, kwargs)) + # else: + # preprocessed.append((tokens, kwargs)) + pass + + +class MultiModalModel[ConfigType: MultiModalModelConfig](GPTModel[ConfigType]): + # TODO: Can we drop class? + pass + + +class MultiModalInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[MultiModalModel]] = MultiModalModel + batch_config_class: typing.ClassVar[type[MultiModalBatchConfig]] = MultiModalBatchConfig diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py new file mode 100644 index 000000000..c4071aafe --- /dev/null +++ b/fast_llm/models/multimodal/trainer.py @@ -0,0 +1,14 @@ +import logging + +from fast_llm.models.gpt.trainer import GPTTrainer +from fast_llm.models.multimodal.config import MultiModalTrainerConfig + +logger = logging.getLogger(__name__) + + +class MultiModalTrainer[ConfigType: MultiModalTrainerConfig](GPTTrainer[ConfigType]): + def _get_data(self) -> MultiModalData: + return MultiModalData( + config=self._config.data, + distributed_config=self._config.model.distributed, + ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 0de823e2a..d65d33a8b 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -255,7 +255,7 @@ def test_lm_head( logit_weight = torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ).normal_(config.embeddings.hidden_size**-0.5) + ).normal_(config.hidden_size**-0.5) ) else: logit_weight = None From 8f8ef19e0b5ebeafd06c449ae39b5bc890a20595 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 3 Oct 2025 23:48:42 -0400 Subject: [PATCH 15/15] cleanup --- fast_llm/models/gpt/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 2c71a5a16..54ea13dc4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -32,5 +32,4 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) - return parameters if _return_dict else GPTSamplingParameters(**parameters)