diff --git a/Dockerfile b/Dockerfile index 526026fa..6bc900ae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM @@ -47,3 +47,8 @@ COPY --chmod=777 ./tests tests COPY --chmod=777 ./tools tools COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ + +# Set a dummy default user so we don't run in root by default. +# The image is still compatible with any user id. +RUN useradd user +USER user diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 2e4a57de..904325c5 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -28,7 +28,6 @@ optimizer: model: base_model: embeddings: - hidden_size: 4096 vocab_size: 32000 dropout: 0.0 decoder: @@ -58,6 +57,7 @@ model: normalization: type: rms_norm epsilon: 1.0e-05 + hidden_size: 4096 tied_embedding_weight: false multi_stage: zero_stage: 2 diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb5..88841966 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,6 +32,8 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[list[torch.Tensor]] | None = None + image_positions: list[list[torch.Tensor]] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None @@ -49,12 +51,24 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = any(sample.images is not None for sample in batch) + if has_images: + images = [ + [] if sample.images is None else [torch.from_numpy(image) for image in sample.images] for sample in batch + ] + image_positions = [ + [] if sample.image_positions is None else torch.from_numpy(sample.image_positions) for sample in batch + ] + else: + images, image_positions = None, None return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, + images=images, + image_positions=image_positions, ) diff --git a/fast_llm/data/dataset/gpt/components/__init__.py b/fast_llm/data/dataset/gpt/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fast_llm/data/dataset/gpt/components/config.py b/fast_llm/data/dataset/gpt/components/config.py new file mode 100644 index 00000000..59c41910 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/config.py @@ -0,0 +1,20 @@ +import dataclasses + +from fast_llm.engine.config_utils.data_type import DataType + +# TODO: Store span type? 
+# class SpanType(enum.StrEnum): +# none = "none" +# loss_masking = "loss_masking" +# preference = "preference" + + +@dataclasses.dataclass(kw_only=True) +class GPTMemmapDatasetHeader: + num_documents: int + token_data_type: DataType = DataType.int64 + has_spans: bool = False + has_images: bool = False + + def __post_init__(self): + self.token_data_type = DataType(self.token_data_type) diff --git a/fast_llm/data/dataset/gpt/components/images.py b/fast_llm/data/dataset/gpt/components/images.py new file mode 100644 index 00000000..cf217b5f --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/images.py @@ -0,0 +1,243 @@ +import io +import math +import typing + +import numpy as np +import PIL.Image + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert, div + + +class GPTImageDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + + self._count_cumsum = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._header.num_documents + 1, + offset=offset.value, + ) + offset.value += self._count_cumsum.nbytes + self._sizes = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1] * 2, + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._sizes.nbytes + self._positions = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1], + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._positions.nbytes + + def get( + self, + index: int, + start_offset: int, + end_offset: int, + shift_map: ShiftMap, + buffer_offset: BufferOffset, + parameters: GPTSamplingParameters, + ) -> tuple[list[np.ndarray] | None, np.ndarray | None]: + # We get images from the document, discarding those outside the selected range. 
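+        # The pixel buffers are stored back-to-back right after the document's tokens in the
+        # binary file, so `buffer_offset` is advanced past every image of the document, even
+        # the ones that fall outside [start_offset, end_offset) and are discarded.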
+ images = [] + positions = [] + for image_index in range(self._count_cumsum[index], self._count_cumsum[index + 1]): + image_buffer_size = self._sizes[image_index].prod(initial=3) + image_position = shift_map.shift(self._positions[image_index].item()) + if start_offset <= image_position < end_offset: + images.append( + np.frombuffer( + self._binary_buffer, + dtype=np.dtype(np.uint8), + count=image_buffer_size, + offset=buffer_offset.value, + ).reshape(3, *self._sizes[image_index]) + ) + positions.append(self._positions[image_index]) + + buffer_offset.value += image_buffer_size + + def _get_insert(self, image_index: int, parameters: GPTSamplingParameters): + height, width = resized_image_length + height_patches = div(height, parameters.patch_size) + width_patches = div(width, parameters.patch_size) + image_size = height_patches * width_patches + if parameters.image_break_token is not None: + image_size += height_patches + elif parameters.image_end_token is not None: + image_size += 1 + + image_token_array = np.full((image_size,), -100, dtype=np.int64) + if parameters.image_break_token is not None: + for row in range(height_patches): + position = (row + 1) * width_patches + row + image_token_array[position] = parameters.image_break_token + + if parameters.image_end_token is not None: + # Will override the last image_break_token. + image_token_array[-1] = parameters.image_end_token + + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, parameters.patch_size) + num_patches_w = div(width, parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if parameters.image_end_token is not None: + image_token_array[last_row_position] = parameters.image_end_token + else: + image_token_array[last_row_position] = parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if parameters.image_end_token is not None: + image_token_array[-1] = parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_sizes[idx] + start_pos = im_position + + resized_image_lengths = [ + get_resize_dims( + *image_length, + parameters.max_image_size, + parameters.max_image_size, + parameters.patch_size, + ) + for image_length in image_lengths + ] + return images, positions + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + has_images = document.images is not None + if "has_images" in index_data: + Assert.eq(index_data["has_images"], has_images) + else: + index_data["has_images"] = has_images + if has_images: + if "image_sizes" not in index_data: + 
index_data["image_sizes"] = [] + if "image_positions" not in index_data: + index_data["image_positions"] = [] + if "num_pixels" not in index_data: + index_data["num_pixels"] = 0 + for image, image_position in zip(document.images, document.image_positions, strict=True): + # assume 3 channels (RGB) for all images + # TODO: Not consistent with GPTSample? + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." + index_data["image_sizes"].append(np.array(pixels.shape[1:])) + index_data["image_positions"].append(image_position) + # TODO: Shouldn't pixel count exclude the channel dimension? + index_data["num_pixels"] += pixels.size + binary_stream.write(pixels.tobytes(order="C")) + # Cumsum holds both image counts and buffer offsets. + if "image_cumsum" not in index_data: + index_data["image_cumsum"] = [0] + index_data["image_cumsum"].append(len(index_data["image_sizes"])) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + if index_data["has_images"]: + Assert.leq(index_data["image_cumsum"][-1], np.iinfo(np.int32).max) + Assert.eq(len(index_data["image_cumsum"]), index_data["num_documents"] + 1) + Assert.eq(len(index_data["image_sizes"]), index_data["image_cumsum"][-1]) + Assert.eq(len(index_data["image_positions"]), index_data["image_cumsum"][-1]) + index_stream.write(np.array(index_data["image_cumsum"], dtype=np.int32).tobytes(order="C")) + # n_pixels * 3 per image + index_stream.write(np.stack(index_data["image_sizes"], dtype=np.int32).tobytes(order="C")) + # Position of each image in the document + index_stream.write(np.array(index_data["image_positions"], dtype=np.int32).tobytes(order="C")) + + def get_sizes(self, index: int, parameters: GPTSamplingParameters) -> list[int]: + return [ + get_num_image_tokens( + *get_resize_dims( + *size.item(), + parameters.max_image_size, + parameters.max_image_size, + parameters.patch_size, + ), + parameters.patch_size, + image_break=parameters.image_break_token is not None, + image_end=parameters.image_end_token is not None, + ) + for size in self._sizes[self._count_cumsum[index] : self._count_cumsum[index + 1]] + ] + + def get_unshifted_positions_and_sizes( + self, index: int, parameters: GPTSamplingParameters + ) -> list[tuple[int, int]]: + return [ + (position, size) + for position, size in zip( + self._positions[self._count_cumsum[index] : self._count_cumsum[index + 1]], + self.get_sizes(index, parameters), + strict=True, + ) + ] + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. 
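+    For example, with max dimensions 512x512 and a patch size of 16, a 1000x600 image is
+    first scaled down to 512x307 and then rounded up to the next patch multiple, giving (512, 320).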
+ If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) diff --git a/fast_llm/data/dataset/gpt/components/spans.py b/fast_llm/data/dataset/gpt/components/spans.py new file mode 100644 index 00000000..a4c331e0 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/spans.py @@ -0,0 +1,73 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert + + +class GPTSpansDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + + self._count_cumsum = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._header.num_documents + 1, + offset=offset.value, + ) + offset.value += self._count_cumsum.nbytes + self._spans = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1] * 2, + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._spans.nbytes + + def get(self, index: int, start_offset: int, end_offset: int, shift_map: ShiftMap) -> list[tuple[int, int]]: + loss_masking_spans = [] + for span_begin, span_end in self._spans[self._count_cumsum[index] : self._count_cumsum[index + 1]].tolist(): + span_begin = max(shift_map.shift(span_begin), start_offset) - start_offset + span_end = min(shift_map.shift(span_end), end_offset - 1) - start_offset + if span_end > span_begin: + loss_masking_spans.append((span_begin, span_end)) + return loss_masking_spans + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + has_spans = document.loss_masking_spans is not None + if "has_span" in index_data: + Assert.eq(index_data["has_span"], has_spans) + else: + index_data["has_span"] = has_spans + if has_spans: + if "spans" not in index_data: + index_data["spans"] = [] + index_data["spans"].extend(document.loss_masking_spans) + if "spans_cumsum" not in index_data: + index_data["spans_cumsum"] = [0] + index_data["spans_cumsum"].append(len(index_data["spans"])) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + if index_data["has_spans"]: + # Should be ok, checking just in case. 
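+            # Index layout: a (num_documents + 1) int32 cumsum of span counts, followed by the
+            # flattened (begin, end) pairs, matching what `__init__` reads back.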
+ Assert.leq(index_data["spans_cumsum"][-1], np.iinfo(np.int32).max) + Assert.eq(len(index_data["spans_cumsum"]), index_data["num_documents"] + 1) + Assert.eq(len(index_data["spans"]), index_data["spans_cumsum"][-1]) + index_stream.write(np.array(index_data["spans_cumsum"], dtype=np.int32).tobytes(order="C")) + index_stream.write(np.vstack(index_data["spans"], dtype=np.int32).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/components/tokens.py b/fast_llm/data/dataset/gpt/components/tokens.py new file mode 100644 index 00000000..3a91fef5 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/tokens.py @@ -0,0 +1,63 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert + + +class GPTTokensDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + self.sizes = np.frombuffer( + self._index_binary_buffer, dtype=np.int32, count=self._header.num_documents, offset=offset.value + ) + self._item_size = self._header.token_data_type.numpy.itemsize + offset.value += self.sizes.nbytes + + def get( + self, index: int, start_offset: int, end_offset: int, shift_map: ShiftMap, buffer_offset: BufferOffset + ) -> np.ndarray: + unshifted_start_offset = shift_map.unshift(start_offset) + token_ids = np.frombuffer( + self._binary_buffer, + dtype=self._header.token_data_type, + count=shift_map.unshift(end_offset) - unshifted_start_offset, + offset=buffer_offset.value + unshifted_start_offset * self._item_size, + ) + buffer_offset.value += self.sizes[index] * self._item_size + return token_ids + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + if "token_data_type" in index_data: + Assert.eq(document.token_ids.dtype, index_data["token_data_type"]) + else: + index_data["token_data_type"] = document.token_ids.dtype + if "document_lengths" not in index_data: + index_data["document_lengths"] = [] + index_data["document_lengths"].append(document_length := len(document.token_ids)) + if "num_tokens" not in index_data: + index_data["num_tokens"] = 0 + index_data["num_tokens"] += document_length + + # Write document to binary file + binary_stream.write(document.token_ids.tobytes(order="C")) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + # Document (tokens) lengths. 
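+        # Written as one int32 entry per document; read back in `__init__` as `self.sizes`.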
+ index_stream.write(np.array(index_data["document_lengths"], dtype=np.int32).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 656cd7d2..782ed6b4 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -61,7 +61,15 @@ class GPTSamplingConfig(SamplingConfig): @dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): +class ImageSamplingParameters: + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None + + +@dataclasses.dataclass(kw_only=True) +class GPTSamplingParameters(SamplingParameters, ImageSamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ @@ -70,6 +78,7 @@ class GPTSamplingParameters(SamplingParameters): vocab_size: int use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False + use_images: bool = False cross_document_attention: bool = True truncate_documents: bool = True # How many extra tokens to add to the sequence length. @@ -138,11 +147,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3b..843f6735 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, add_eos=False)], dtype=np.int64) + middle = np.array([*self._tokenizer.tokenize(middle, add_bos=False, add_eos=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer.tokenize(suffix, add_bos=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? 
diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 89622977..d2bd84bf 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -3,7 +3,7 @@ import numpy as np -from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset if typing.TYPE_CHECKING: @@ -12,7 +12,7 @@ class GPTIndexedDataset(IndexedDataset): @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, @@ -20,7 +20,7 @@ def get_document_sizes(self) -> np.ndarray: """ @abc.abstractmethod - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: """ The size of a document in the dataset. """ @@ -38,12 +38,12 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe _dataset: GPTIndexedDataset - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + return self._dataset.get_document_sizes(parameters)[self._begin : self._end] - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: + return self._dataset.get_document_size(self._begin + index, parameters) class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( @@ -51,10 +51,10 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ): _datasets: list[GPTIndexedDataset] - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + return np.concatenate([dataset.get_document_sizes(parameters) for dataset in self._datasets]) - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item(), parameters) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f..99d9957f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,16 +1,52 @@ +import functools +import json import pathlib import struct import typing import numpy as np +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.components.images import GPTImageDatasetComponent +from fast_llm.data.dataset.gpt.components.spans import GPTSpansDatasetComponent +from fast_llm.data.dataset.gpt.components.tokens import GPTTokensDatasetComponent +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_INDEX_HEADER from fast_llm.utils import Assert, div +class BufferOffset: + # This makes offsets mutable. + def __init__(self, value: int): + self.value: int = value + + +class ShiftMap: + """ + A map between original and shifted token indices (i.e., accounting for extra content such as images). + Also serves as a cache so we don't have to recompute positions and sizes every time. 
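+
+    Example: with a single image occupying 9 token slots at original text position 4,
+    `shift(6) == 15` and `unshift(15) == 6`, while indices at or before 4 are unchanged.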
+ """ + + def __init__(self, positions_and_sizes: list[tuple[int, int]]): + self._positions_and_sizes = positions_and_sizes + + @functools.cached_property + def shifted_positions(self) -> list[int]: + return [self.shift(position) for position, _ in self._positions_and_sizes] + + def shift(self, index: int) -> int: + return index + sum(size for position, size in self._positions_and_sizes if index > position) + + def unshift(self, index: int) -> int: + return index - sum( + size + for shifted_position, (_, size) in zip(self.shifted_positions, self._positions_and_sizes, strict=True) + if shifted_position < index + ) + + class GPTMemmapDataset(GPTIndexedDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, @@ -26,293 +62,241 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None = None, + num_tokens: int | None = None, + num_pixels: int | None = None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 2 and bool(struct.unpack("= 3 and bool(struct.unpack("= 4 and bool(struct.unpack("= 2: - self._spans = [] - self._num_spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, - ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes - self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] - for idx in range(self._num_documents): - self._spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") - self._bin_buffer = memoryview(self._bin_buffer_mmap) - - self._num_tokens = div(self._bin_buffer_mmap.size, 
np.dtype(self._dtype).itemsize) + self._spans = ( + GPTSpansDatasetComponent(self._header, self._index_binary_buffer, self._binary_buffer, offset) + if self._header.has_spans + else None + ) + self._images = ( + GPTImageDatasetComponent(self._header, self._index_binary_buffer, self._binary_buffer, offset) + if self._header.has_images + else None + ) + + if num_pixels is not None: + Assert.eq(num_pixels, self._images.total_pixels) + + # TODO: Simplify. + self._num_tokens = ( + self._binary_buffer_mmap.size + if self._images is None + else self._binary_buffer_mmap.size - self._images.total_pixels + ) if num_tokens is not None: - assert self._num_tokens == num_tokens + Assert.eq(num_tokens, self._num_tokens) - def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._prefix) - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): if hasattr(self, "_bin_buffer_mmap"): - self._bin_buffer_mmap._mmap.close() # noqa - del self._bin_buffer_mmap + self._binary_buffer_mmap._mmap.close() # noqa + del self._binary_buffer_mmap if hasattr(self, "_index_bin_buffer"): - self._index_bin_buffer_mmap._mmap.close() # noqa - del self._index_bin_buffer_mmap + self._index_binary_buffer_mmap._mmap.close() # noqa + del self._index_binary_buffer_mmap def get( self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, + index: int, + start_offset: int = 0, + end_offset: int | None = None, + parameters: GPTSamplingParameters | None = None, ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + if end_offset is None: + end_offset = self.get_document_size(index, parameters) - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], 
offset + len(token_ids) - 1) - offset + shift_map = ShiftMap( + self._images.get_unshifted_positions_and_sizes(index, parameters) if parameters.use_images else [] + ) - rejected_span = self._rejected_spans[idx] + buffer_offset = BufferOffset(self._buffer_offsets[index].item()) + sample = GPTSample(token_ids=self._tokens.get(index, start_offset, end_offset, shift_map, buffer_offset)) - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] + if parameters.use_loss_masking_spans: + sample.loss_masking_spans = self._spans.get(index, start_offset, end_offset, shift_map) - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + if parameters.use_images: + sample.images, sample.image_positions = self._images.get( + index, start_offset, end_offset, shift_map, buffer_offset + ) - return GPTSample( - token_ids=token_ids, - loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, - ) + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, parameters.patch_size) + num_patches_w = div(width, parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if parameters.image_end_token is not None: + image_token_array[last_row_position] = parameters.image_end_token + else: + image_token_array[last_row_position] = parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if parameters.image_end_token is not None: + image_token_array[-1] = parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_sizes[idx] + start_pos = im_position + + return sample @property def name(self) -> str: return self._name def __len__(self) -> int: - return self._num_documents + return self._header.num_documents - @property - def num_tokens(self) -> int: - return self._num_tokens - - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + if parameters is not None and parameters.use_images: + # TODO: Optimize this. 
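+            # With images enabled, a document's size depends on the sampling parameters
+            # (patch size, break/end tokens), so the cached token sizes alone are not enough.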
+ return np.array([self.get_document_size(index, parameters) for index in range(self._header.num_documents)]) + return self._tokens.sizes + + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: + size = self._tokens.sizes[index].item() + if parameters is not None and parameters.use_images: + for _, size_ in self._images.get_positions_and_sizes(index, parameters): + size += size_ + return size + + def _shift_offset(self, offset, index: int, parameters: GPTSamplingParameters | None = None) -> int: + if parameters is not None and parameters.use_images: + offset += sum( + size for position, size in self._images.get_positions_and_sizes(index, parameters) if position < offset + ) + return offset - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + def _unshift_offset(self, offset, index: int, parameters: GPTSamplingParameters | None = None) -> int: + unshifted_offset = offset + if parameters is not None and parameters.use_images: + for position, size in self._images.get_positions_and_sizes(index, parameters): + shifted_position = self._shift_offset(position, index, parameters) + if shifted_position < offset: + unshifted_offset -= size + return unshifted_offset @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): - # Initialize metadata - dtype = None + buffer_offsets = [] + index_data = {} num_documents = 0 - lengths = [] - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] + component_classes = (GPTTokensDatasetComponent, GPTSpansDatasetComponent, GPTImageDatasetComponent) prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: + with prefix.with_suffix(".bin").open("wb") as binary_stream: + for document in documents: - # Infer dtype from the first document - if dtype is None: - dtype = document.token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." 
- - # Write document to binary file - bin_stream.write(document.token_ids.tobytes(order="C")) - - # Update metadata - doc_length = len(document.token_ids) - lengths.append(doc_length) - pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize - num_documents += 1 + buffer_offsets.append(binary_stream.tell()) + for component_class in component_classes: + component_class.write_document_and_gather_index(document, index_data, binary_stream) - # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + # TODO: Address + assert document.chosen_span is None and document.rejected_span is None + + num_documents += 1 # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Data type - idx_stream.write(struct.pack(" None: """ Create a `GPTSampledDataset` with the requested parameters. """ - # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + # Get the size each document, the main information needed for sampling. + # Note: "document" may refer to more than just text. + document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes(self._parameters)).to(self._device) + + documents_per_epoch, tokens_per_epoch, long_docs_filter = self._get_epoch_size(document_sizes) + num_epochs, shuffled_epochs = self._get_epoch_count(documents_per_epoch, tokens_per_epoch) + + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data, cached = self._get_and_compare_yaml_data(documents_per_epoch, tokens_per_epoch, unshuffled_epochs) + if cached: + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." + ) - # Calculate basic stats. 
- if not self._truncate_documents: + document_shuffling = self._get_document_shuffling(documents_per_epoch, shuffled_documents, shuffled_epochs) + + # To get a sample on the fly we need to know where it begins, + # and this is a non-trivial information because the documents have variable length. + # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. + # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. + # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. + # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + + # TODO: Allowing for max 100% extra tokens for padding, is that enough? + cumsum_dtype = get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs) + if unshuffled_epochs > 0: + token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(document_sizes, 0, cumsum_dtype) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) + else: + unshuffled_tokens = 0 + + if shuffled_epochs > 0: + token_cumsum_shuffled, _ = self._get_token_cumsum( + document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + self._unshuffled_tokens, + cumsum_dtype, + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(force=True) + ) + + yaml_data["unshuffled_tokens"] = unshuffled_tokens + self._load_yaml_data(yaml_data) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + + def _get_epoch_size(self, document_sizes: torch.Tensor) -> tuple[int, int, torch.Tensor | None]: + documents_per_epoch = document_sizes.numel() + if self._truncate_documents: + tokens_per_epoch = document_sizes.sum().item() + long_docs_filter = None + else: assert _extension_available, ( "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 - ignored_documents = long_docs_filter.sum().item() - if ignored_documents: + long_docs_filter = document_sizes <= self._parameters.sequence_length + 1 + documents_per_epoch_filtered = long_docs_filter.sum().item() + if ignored_documents := documents_per_epoch_filtered - documents_per_epoch: log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", + f" > {ignored_documents}/{documents_per_epoch} documents" + f" are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + tokens_per_epoch = document_sizes[long_docs_filter].sum().item() if tokens_per_epoch == 0: raise RuntimeError( - f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." 
+ f" > No documents shorter than {self._parameters.sequence_length+1}" + f" tokens found in dataset {self._indexed_dataset.name}." ) + return documents_per_epoch, tokens_per_epoch, long_docs_filter + def _get_epoch_count(self, documents_per_epoch: int, tokens_per_epoch: int) -> tuple[int, int]: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. - if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) - elif self._truncate_documents: + if self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -174,32 +246,34 @@ def _sample(self) -> None: ) # Prepare for shuffling. - generator = torch.Generator(device=self._device) if self._config.shuffle == ShufflingType.skip_first_epoch: shuffled_epochs = num_epochs - 1 elif self._config.shuffle == ShufflingType.disabled: shuffled_epochs = 0 else: shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs + return num_epochs, shuffled_epochs + def _get_and_compare_yaml_data( + self, + documents_per_epoch: int, + tokens_per_epoch: int, + unshuffled_epochs: int, + ) -> tuple[dict[str, typing.Any], bool]: yaml_data = { "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, "tokens_per_epoch": tokens_per_epoch, }, - "num_samples": self._parameters.num_samples, + "sampling": self._parameters.__dict__, "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._parameters.sequence_length, - "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } if self._truncate_documents: yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs - if self._yaml_path is not None and self._yaml_path.is_file(): + if cached := (self._yaml_path is not None and self._yaml_path.is_file()): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: @@ -216,120 +290,8 @@ def _sample(self) -> None: ) # Dataset is already sampled, skip. logger.info(f"Using existing sampling for dataset {self.name}") - return - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. 
- if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - - # To get a sample on the fly we need to know where it begins, - # and this is a non-trivial information because the documents have variable length. - # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. - # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. - # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if unshuffled_epochs > 0: - token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, - offset=0, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
- dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) - else: - unshuffled_tokens = 0 - - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens - self._load_yaml_data(yaml_data) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - if shuffled_epochs > 0: - token_cumsum_shuffled, _ = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 - ) - ], - offset=self._unshuffled_tokens, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled) - self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( - force=self._config.gpu - ) - ) - # Free memory - del document_shuffling + return yaml_data, cached def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._truncate_documents: @@ -372,6 +334,50 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - ] return out, num_tokens + def _get_document_shuffling( + self, + documents_per_epoch: int, + shuffled_documents: int, + shuffled_epochs: int, + ) -> torch.Tensor | None: + generator = torch.Generator(device=self._device) + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. 
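+        # `full` shuffles all epochs at once; `epoch` and `skip_first_epoch` reshuffle each
+        # shuffled epoch independently with a distinct seed (`seed + 571 * i`).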
+ if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + return document_shuffling + def __len__(self) -> int: return self._parameters.num_samples @@ -383,39 +389,6 @@ def __getitem__(self, index: int) -> typing.Any: """ self._lazy_load() - if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) - - return sample - # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -441,7 +414,13 @@ def __getitem__(self, index: int) -> typing.Any: token_count = token_start_array[token_start_cumsum_index] token_ids = [] - loss_masking_spans = [] + if self._parameters.use_loss_masking_spans: + loss_masking_spans = [] + if self._parameters.use_images: + images = [] + image_positions = [] + image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
if document_sampling_index < self._unshuffled_documents: @@ -449,7 +428,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + document_size = self._indexed_dataset.get_document_size(document_index, self._parameters) if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -468,21 +447,79 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size + continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( + sample: GPTSample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - token_ids.append(sample.token_ids) + if self._parameters.use_images: + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if self._parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = self._parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_tokens_added += image_sizes[idx] + start_pos = im_position + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + images.append(sample.images) + else: + token_ids.append(sample.token_ids) + text_tokens_added += len(token_ids[-1]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: 
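+                        # Grow both span boundaries by the sizes of the image placeholder
+                        # tokens inserted before them, so the span still covers the same text.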
+ if self._parameters.use_images: + # Shift the spans to account for the images. + loss_masking_span[0] += sum( + image_size + for image_size, image_position in zip(image_sizes, sample.image_positions) + if image_position < loss_masking_span[0] + ) + loss_masking_span[1] += sum( + image_size + for image_size, image_position in zip(image_sizes, sample.image_positions) + if image_position < loss_masking_span[1] + ) + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -500,34 +537,59 @@ def __getitem__(self, index: int) -> typing.Any: if not self._parameters.cross_document_attention else None ) + token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if self._parameters.use_images else None + image_positions = np.array(image_positions) if self._parameters.use_images else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: return self._indexed_dataset.name + def _get_image_sizes(self, document_index: int): + # TODO: Duplicate of _get_document_sizes + image_lengths = self._indexed_dataset.get_image_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + return resized_image_lengths, image_sizes, image_tokens + def _lazy_load(self): if not hasattr(self, "_documents_per_epoch"): self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - - if self._parameters.use_preference_loss_spans: - data["unshuffled_tokens"] = 0 # not used, ignore - elif "unshuffled_tokens" not in data: - # Backward compatibility - # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e..da353793 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing 
image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -175,6 +187,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 33c40bf8..94eede19 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -8,6 +10,7 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm @@ -24,7 +27,11 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -39,36 +46,44 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _loss_masking_spans_column: str | None def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) + for input_ids, token_spans, image_token_positions in [ + self._tokenizer.tokenize( + text, + loss_mask_spans, + im_char_positions, + ) + for text, loss_mask_spans, im_char_positions in zip( + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), + ) ] ] ), ) num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 + return { "input_ids": input_ids, + "image_positions": image_token_positions, "token_spans": 
token_spans, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -141,27 +156,22 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + has_preference_spans = ( + self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None + ) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._loss_masking_spans_column + else None + ), + item["chosen_token_spans"] if has_preference_spans else None, + item["rejected_token_spans"] if has_preference_spans else None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -171,6 +181,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -290,6 +301,11 @@ def run(self) -> None: if isinstance(self._config.dataset.source_schema, TextColumnConfig): self._text_column = self._config.dataset.source_schema.input_column self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column + if isinstance(self._config.dataset.source_schema, TextImageColumnConfig): + self._images_column = self._config.dataset.source_schema.images_column + self._image_positions_column = self._config.dataset.source_schema.image_positions_column + # decoding bytes to images is slow and should be done only when needed + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) else: raise ValueError( f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." 
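# Illustrative sketch (not part of the diff): how the preparator's image bookkeeping fits together.
# Raw pixel bytes (height * width * 3) are folded into the shard-size budget by converting them to
# token-equivalents with the token dtype's item size, while dataset splitting estimates image tokens
# from pixels per patch area. The helper name below is hypothetical.
import numpy as np

def image_token_accounting(image_sizes: list[tuple[int, int]], patch_size: int = 16, token_dtype=np.int64):
    # `image_sizes` holds (height, width) pairs for one document's images.
    num_pixels = sum(height * width * 3 for height, width in image_sizes)
    # Token-equivalent size used when deciding how many tokens go into each shard.
    shard_token_equivalent = num_pixels // np.dtype(token_dtype).itemsize
    # Patch-token estimate used when computing split boundaries and blending weights.
    split_token_estimate = sum(height * width for height, width in image_sizes) // patch_size**2
    return num_pixels, shard_token_equivalent, split_token_estimate

# One 256x512 image: 393216 pixel bytes, 49152 int64 token-equivalents, 512 estimated patch tokens.
print(image_token_accounting([(256, 512)]))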
@@ -298,18 +314,17 @@ def run(self) -> None: if self._text_column not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._text_column}'.") - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( + if self._loss_masking_spans_column is not None and ( self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._loss_masking_spans_column not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: if self._config.dataset.chosen_text not in dataset.column_names: raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") @@ -329,6 +344,13 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._images_column + else 0 + ) + # Add the token-equivalent bytes of pixels to determine shard size + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -357,7 +379,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -397,7 +419,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -427,10 +453,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. 
# TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -443,8 +479,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c7458620..23a839af 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -41,44 +41,77 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, add_bos=True, add_eos=True, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ - input_ids = [] + if not image_positions: + image_positions = [] + if not char_spans: + char_spans = [] + + # Collect all positions with their type + positions = [] + for pos in image_positions: + positions.append((pos, "image")) + + for start, end in char_spans: + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. 
We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + + token_ids = [] token_spans = [] + image_token_positions = [] char_pos = 0 - beginning_of_text = True + current_span_start = None - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 + for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times + if char_pos < position[0]: + tokenized_text = self._tokenize( + text[char_pos : position[0]], + begin=add_bos and char_pos == 0, + end=add_eos and position[0] > len(text) - 1, + ) + token_ids.extend(tokenized_text) + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + tokenized_text = self._tokenize(text[char_pos:], begin=add_bos and char_pos == 0, end=add_eos) + token_ids.extend(tokenized_text) + + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index afe38129..96fb5332 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -150,7 +150,6 @@ def _load_weights( ].values() } elif (config.path / transformers.utils.WEIGHTS_NAME).is_file(): - # TODO: Prevent unsafe by default paths = {config.path / transformers.utils.WEIGHTS_NAME} elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") @@ -170,7 +169,7 @@ def _load_weights( for key in f.keys(): yield key, "weights", f.get_slice(key) elif path.suffix == ".bin": - # TODO: Prevent unsafe by default - yield from torch.load(path) + # TODO: Confirm that loading works with `weights_only=True` + yield from torch.load(path, weights_only=True) else: raise NotImplementedError(f"Unknown file 
format for {path}") diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index 76416d36..c0910c09 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -1,7 +1,8 @@ import math import typing -from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.initialization import Initialization, InitializationConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.peft.config import PeftConfig @@ -36,7 +37,7 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): @config_class() -class ParameterConfig(Config): +class ParameterConfig(ModuleConfig): initialization: InitializationConfig = Field( desc="If provided, override the default initialization method set by the parent layer.", hint=FieldHint.feature, diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 26877ee0..74b5cf21 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -135,3 +135,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "default_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index d57d7294..6250fd4a 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -12,6 +12,7 @@ DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) @@ -174,3 +175,49 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: Rotary2DConfig](DefaultRotary[ConfigType]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + _config: ConfigType + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors( + kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size], batch.device + ) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def forward( + self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] + ) -> tuple[torch.Tensor, torch.Tensor]: + rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) + return query, key + + def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.device) -> torch.Tensor: + max_num_patches = sequence_length + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + 
width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / head_size, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, head_size // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), head_size, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index ab6cb22b..67ce5eea 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -103,6 +103,7 @@ def __init__( config: ConfigType, distributed_config: DistributedConfig, *, + # TODO: Review. Use `input_dim(s)` and `output_dim(s)` instead? hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e9..0dc11826 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,7 +1,12 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import ( + Initialization, + init_normal_, + init_uniform_centered_, + init_zeros_, +) from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.config import ActivationType @@ -9,7 +14,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear.convolution import CausalConv1d + from fast_llm.layers.common.linear.convolution import CausalConv1d, Convolution2D from fast_llm.layers.common.linear.linear import LinearBase @@ -217,3 +222,44 @@ def get_layer( return CausalConv1d( weight, bias, activation=default_activation if self.activation is None else self.activation ) + + +@config_class +class Convolution2DConfig(AffineLinearBaseConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + kernel_dim_1: TensorDim, + kernel_dim_2: TensorDim, + *, + stride: tuple[int, int], + default_weight_initialization: Initialization | None = None, + default_bias_initialization: Initialization | None = None, + default_add_bias: bool = True, + lr_scale: float | None, + peft: PeftConfig | None, + ) -> "Convolution2D": + from fast_llm.layers.common.linear.convolution import Convolution2D + + if default_weight_initialization is None: + default_weight_initialization = init_normal_() + if default_bias_initialization is None: + default_bias_initialization = init_normal_() + + lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),) + weight = self.weight.get_parameter( + (out_dim, in_dim, kernel_dim_1, kernel_dim_2), + default_initialization=default_weight_initialization, + lr_scale=lr_scale, + peft=peft, + ) + bias = self.bias.get_parameter( + (out_dim,), + default_initialization=default_bias_initialization, + lr_scale=lr_scale, + 
default_enabled=default_add_bias, + peft=peft, + ) + + return Convolution2D(weight, bias, stride=stride) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index b88b7b2e..6281348e 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -55,3 +55,27 @@ def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: raise NotImplementedError() + + +class Convolution2D(torch.nn.Module): + """ + TODO: Generalize to other convolutions? + """ + + def __init__( + self, + weight: ParameterMeta, + bias: ParameterMeta | None, + *, + stride: tuple[int, int], + ): + super().__init__() + self.weight = weight + self.bias = bias + self._stride = stride + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._stride) + + def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: + raise NotImplementedError() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d2fbc490..25fa2d91 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,7 +2,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -16,6 +15,7 @@ if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase + from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction @@ -41,12 +41,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) - hidden_size: int = Field( - default=1024, - desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", @@ -295,24 +289,29 @@ def max_prediction_distance(self) -> int: @config_class() -class LanguageModelConfig(ModuleConfig): - # TODO: block +class LanguageModelConfig(BlockConfig): decoder: BlockSequenceConfig = Field( desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) - embeddings: LanguageModelEmbeddingsConfig = Field() - head: LanguageModelHeadBaseConfig = Field() - # TODO: Allow overriding in sub-models? - peft: PeftConfig = Field( - desc="Configuration for parameter-efficient fine tuning.", + embeddings: LanguageModelEmbeddingsConfig = Field( hint=FieldHint.architecture, + desc="Configuration for the language model embeddings.", + ) + head: LanguageModelHeadBaseConfig = Field( + hint=FieldHint.architecture, desc="Configuration for the language model head(s)." 
) tied_embedding_weight: bool = Field( default=False, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.architecture, ) + hidden_size: int = Field( + default=1024, + desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) sequence_first: bool | None = Field( default=None, desc="Override the default dimension ordering", @@ -321,3 +320,9 @@ class LanguageModelConfig(ModuleConfig): " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", hint=FieldHint.testing, ) + + @property + def layer_class(self) -> "type[LanguageModel]": + from fast_llm.layers.language_model.language_model import LanguageModel + + return LanguageModel diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 0ad3225c..17ee3261 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import reduce_forward, split +from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim @@ -14,6 +14,8 @@ from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert +WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" + class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[ConfigType]): """ @@ -26,7 +28,8 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType - # Position embedding preprocessing + # Preprocessing + _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 @@ -75,34 +78,62 @@ def __init__( ) @torch.compile - def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: + def _forward( + self, + input_: torch.Tensor, + token_ids: torch.Tensor, + position_ids: torch.Tensor | None, + mask_inputs: bool, + # TODO: Flatten the batch and sequence in the map? + embedding_map: tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None, + ) -> torch.Tensor: Assert.eq(position_ids is None, self.position_embeddings_weight is None) group = self._parallel_dim.group if self._vocab_parallel: - input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) - masked_input = (input_ - self._vocab_start_index) * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa + token_mask = (token_ids >= self._vocab_start_index) * (token_ids < self._vocab_end_index) + masked_input = (token_ids - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) + # TODO: Input masking of position embeddings inconsistant with non-vocab-parallel if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + + if embedding_map is not None: + # TODO: Accumulate redundant with masking? 
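+ # Note (assumption based on the usage below): `embedding_map` pairs indices into the incoming
+ # patch-embedding tensor (`input_index`) with the positions of the image placeholder tokens in the
+ # embedding output (`embedding_index`); `index_put(..., accumulate=True)` then scatters the patch
+ # embeddings on top of the word embeddings at those positions.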
+ input_index, embedding_index = embedding_map + if self._sequence_parallel: + input_ = gather(input_, group=group, dim=0) + embeddings = embeddings.index_put(embedding_index, input_[input_index], accumulate=True) + if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: - input_ = split(input_, group=group, dim=0) + token_ids = split(token_ids, group=group, dim=0) if self.position_embeddings_weight is not None: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: - input_mask = input_ >= 0 - masked_input = input_ * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) - else: - embeddings = torch.embedding(self.word_embeddings_weight, input_) + token_mask = token_ids >= 0 + token_ids = token_ids * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, token_ids) if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: - embeddings = embeddings * input_mask.unsqueeze(2) + embeddings = embeddings * token_mask.unsqueeze(2) + + if embedding_map is not None: + # TODO: Accumulate redundant with masking? + input_index, embedding_index = embedding_map + if self._sequence_parallel: + # TODO:: Filter and shift embedding map instead? (needs cuda sync) + input_ = gather(input_, group=group, dim=0) + embeddings_ = embeddings.new_zeros(embeddings.shape[0] * group.size(), *embeddings.shape[1:]) + embeddings_.index_put(embedding_index, input_[input_index], accumulate=True) + embeddings = embeddings + split(embeddings_, group=group, dim=0) + else: + embeddings = embeddings.index_put(embedding_index, input_[input_index], accumulate=True) + with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): @@ -119,11 +150,17 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims( kwargs[LanguageModelKwargs.hidden_dims], - tensor_name=f"{self.module_name} output", + tensor_name="Embedding output", dtype=self._residual_dtype, ) + return self._forward( - input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs) + input_, + kwargs.get(LanguageModelKwargs.token_ids), + kwargs.get(LanguageModelKwargs.position_ids), + # TODO ====== Vision ====== Review input masking. 
+ kwargs.get(LanguageModelKwargs.mask_inputs), + kwargs.get(LanguageModelKwargs.embedding_map), ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 9a3bef19..2e46bb57 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -1,52 +1,64 @@ import logging import typing -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer, LayerBase +import torch + +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding logger = logging.getLogger(__name__) -class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase): +class LanguageModel[ConfigType: LanguageModelConfig](BlockBase[ConfigType]): _config: ConfigType def __init__( self, config: ConfigType, distributed_config: DistributedConfig, + *, + # TODO: Unused, but required by the `BlockBase` interface. + hidden_dim: TensorDim | None = None, + lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, distributed_config) - - self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + super().__init__( + config, + distributed_config, + hidden_dim=TensorDim("hidden", config.hidden_size), + lr_scale=lr_scale, + peft=peft, + ) self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( distributed_config, hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) self.decoder = self._config.decoder.get_layer( distributed_config, self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) self.head = self._config.head.get_layer( distributed_config, self._config.embeddings, hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) - def get_layers(self) -> list["Layer"]: + def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() - def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/layers/vision/__init__.py b/fast_llm/layers/vision/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py new file mode 100644 index 00000000..1af986ee --- /dev/null +++ b/fast_llm/layers/vision/config.py @@ -0,0 +1,169 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig +from fast_llm.layers.common.linear.config import Convolution2DConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.decoder.config import MLPBaseConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.vision.vision_encoder import VisionEncoder + + +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + +@config_class() +class PatchConvolutionConfig(BlockConfig): + _abstract = False + convolution: Convolution2DConfig = Field( + desc="Configuration for the 2d convolution.", + hint=FieldHint.architecture, + ) + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layer.", + hint=FieldHint.architecture, + ) + patch_size: int = Field( + default=16, + desc="Size of image patches, in pixels (width and height).", + hint=FieldHint.core, + ) + input_channels: int = Field( + default=3, + desc="Number of pixel channels (usually 3).", + hint=FieldHint.feature, + ) + + +@config_class(registry=True) +class VisionEncoderConfig(BlockConfig): + _abstract = False + patch_convolution: PatchConvolutionConfig = Field( + desc="Configuration for the patch convolution layer.", + hint=FieldHint.architecture, + ) + adapter: MLPBaseConfig = Field( + desc="Configuration for the adapter layer.", + hint=FieldHint.architecture, + ) + # TODO: ====== Appropriate name?? 
====== + decoder: BlockSequenceConfig = Field( + desc="Configuration for the vision decoder.", + hint=FieldHint.architecture, + ) + hidden_size: int = Field( + default=1024, + desc="Size of the vision encoder main hidden dimension.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def layer_class(self) -> "type[VisionEncoder]": + from fast_llm.layers.vision.vision_encoder import VisionEncoder + + return VisionEncoder + + # transformer: TransformerConfig = Field( + # desc="Configuration for the vision transformer architecture.", + # hint=FieldHint.core, + # ) + # patch_size: int = Field( + # default=16, + # desc="Patch size for the image encoder.", + # hint=FieldHint.core, + # ) + # conv_bias: bool = Field( + # default=False, + # desc="Whether to use bias in the convolutional layer.", + # hint=FieldHint.optional, + # ) + # patch_norm: NormalizationConfig = Field( + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional, + # ) + # adapter_size: int = Field( + # default=5120, + # desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + # hint=FieldHint.core, + # ) + # adapter_activation_type: ActivationType = Field( + # default=ActivationType.gelu, + # desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + # hint=FieldHint.core, + # ) + # adapter_bias: bool = Field( + # default=True, + # desc="Whether to use bias in the adapter linear layer.", + # hint=FieldHint.optional, + # ) + # image_normalization: ImageNormalizationConfig = Field( + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional, + # ) + # image_break_token: int | None = Field( + # default=None, + # desc="Token id to separate image rows. If None, no token id is applied.", + # hint=FieldHint.optional, + # ) + # image_end_token: int | None = Field( + # default=None, + # desc="Token id to indicate the end of an image. If None, no token id is applied.", + # hint=FieldHint.optional, + # ) + # adapter_lr_scale: float | None = Field( + # default=None, + # desc="Custom learning rate scale for the adapter weights.", + # hint=FieldHint.feature, + # valid=skip_valid_if_none(check_field(Assert.geq, 0)), + # ) + # conv_lr_scale: float | None = Field( + # default=None, + # desc="Custom learning rate scale for the convolutional layer weights.", + # hint=FieldHint.feature, + # valid=skip_valid_if_none(check_field(Assert.geq, 0)), + # ) + # adapter_init_method_std: float = Field( + # default=None, + # desc="Standard deviation for the normal initialization of the adapter weights. 
Default: adapter_size ** -0.5.", + # hint=FieldHint.optional, + # valid=check_field(Assert.geq, 0), + # ) diff --git a/fast_llm/layers/vision/patch_convolution.py b/fast_llm/layers/vision/patch_convolution.py new file mode 100644 index 00000000..46cf8670 --- /dev/null +++ b/fast_llm/layers/vision/patch_convolution.py @@ -0,0 +1,71 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.block.block import Block +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision.config import PatchConvolutionConfig +from fast_llm.tensor import TensorMeta + + +class PatchConvolution[ConfigType: PatchConvolutionConfig](Block[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + # TODO: Input or output dim? + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + input_dim = TensorDim("input_channels", self._config.input_channels) + patch_dim = TensorDim("patch", self._config.patch_size) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + self.convolution = self._config.convolution.get_layer( + self._hidden_dim, + input_dim, + patch_dim, + patch_dim, + stride=(self._config.patch_size, self._config.patch_size), + default_add_bias=False, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.normalization = self._config.normalization.get_layer(hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + input_.dims[:-1] + (self._hidden_dim,), tensor_name="patch conv output", dtype=input_.dtype + ) + # TODO: Avoid padding + input_ = self.convolution(input_) + patch_embeddings = self.normalization(input_.flatten(1)).view_as(input_) + + # TODO: Permute earlier? + if kwargs[AttentionKwargs.sequence_first]: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + if self._sequence_parallel: + patch_embeddings = split(patch_embeddings, group=self._parallel_dim.group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision/preprocessing.py b/fast_llm/layers/vision/preprocessing.py new file mode 100644 index 00000000..83331c73 --- /dev/null +++ b/fast_llm/layers/vision/preprocessing.py @@ -0,0 +1,194 @@ +import math +import typing + +import torch +import torchvision.transforms.v2 as torchvision_transforms + +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.vision.config import ImageNormalizationConfig, VisionEncoderConfig +from fast_llm.utils import div + + +def get_num_patches(height: int, width: int, patch_size: int) -> int: + """ + Calculate the number of patches in height and width dimensions. 
+ """ + return div(height, patch_size) * div(width, patch_size) + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) + + +def resize(image: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + while max(image.size(1) / target_height, image.size(2) / target_width) > 2: + image = torchvision_transforms.functional.resize( + image, + size=(math.ceil(image.size(1) / 2), math.ceil(image.size(2) / 2)), + interpolation=torchvision_transforms.InterpolationMode.BICUBIC, + ) + + # TODO: options for interpolation mode? + return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + return torch.arange(patch_height).repeat_interleave(patch_width) * max_size + torch.arange(patch_width).repeat( + patch_height + ) + + +class VisionPreprocessor: + def __init__(self, config: VisionEncoderConfig, distributed: Distributed): + self._config = config + self._distributed = distributed + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) + patch_size = self._config.patch_size + image_sizes = [] + + norm_config: ImageNormalizationConfig = kwargs["norm_config"] + + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? 
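+ # Work on a copy: the label positions covered by image tokens are set to -100 in place below,
+ # and the masked copy is written back to `kwargs` at the end of preprocessing.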
+ labels = labels.clone() + patches = [] + patch_position_ids = [] + sequence_lengths = [0] + max_sequence_length = -1 + + for sample_index, (sample_images_, positions) in enumerate( + zip(kwargs[VisionEncoderKwargs.images], kwargs.get(VisionEncoderKwargs.image_positions), strict=True) + ): + image_sizes.append(sample_image_sizes := []) + + sample_sequence_length = 0 + + for image, position in zip(sample_images_, positions, strict=True): + height, width = get_resize_dims( + image.size(1), image.size(2), max_image_size, max_image_size, patch_size=patch_size + ) + + sample_image_sizes.append((height, width)) + + image = resize(image, height, width) + + # TODO: Normalize with constant dtype instead? + image = image.to(dtype=self._distributed.config.training_dtype.torch) + + image = torchvision_transforms.functional.normalize( + image / norm_config.rescale_factor, + mean=[norm_config.mean_r, norm_config.mean_g, norm_config.mean_b], + std=[norm_config.std_r, norm_config.std_g, norm_config.std_b], + ) + patches.extend( + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ) + ) + + num_height_patches = div(height, patch_size) + num_width_patches = div(width, patch_size) + grid_height = torch.arange(num_height_patches).repeat_interleave(num_width_patches) + grid_width = torch.arange(num_width_patches).repeat(num_height_patches) + grid_height * div(max_image_size, patch_size) + grid_width + patch_position_ids.append(grid_height * div(max_image_size, patch_size) + grid_width) + + if LanguageModelKwargs.labels in kwargs: + num_tokens = get_num_image_tokens( + height, + width, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[sample_index, max(position - 1, 0) : position + num_tokens - 1] = -100 + + sequence_lengths.append(sequence_length := num_height_patches * num_width_patches) + if sequence_length > max_sequence_length: + max_sequence_length = sequence_length + sample_sequence_length += sequence_length + + # TODO: No need for padding with varlen? 
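+ # Pad every sample's patch sequence with all-zero patches up to the configured sequence length so
+ # each sample contributes the same number of patches; the pad length is appended to
+ # `sequence_lengths` so it is treated as a separate dummy segment.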
+ padding_size = kwargs[AttentionKwargs.sequence_length] - sample_sequence_length + if padding_size > max_sequence_length: + max_sequence_length = padding_size + sequence_lengths.append(padding_size) + + patches.append( + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ) + patch_position_ids.append(torch.full((padding_size,), 0, dtype=torch.int64)) + + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes + kwargs[VisionEncoderKwargs.image_patches] = torch.cat(patches).to(device=self._distributed.device) + kwargs[VisionTransformerKwargs.patch_position_ids] = torch.cat(patch_position_ids).to( + device=self._distributed.device + ) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size**2, patch_size**2) + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_sequence_length + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_sequence_length + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. + kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[AttentionKwargs.sequence_length], 1, kwargs[AttentionKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed.config.training_dtype.torch).min, + dtype=self._distributed.config.training_dtype.torch, + device=self._distributed.device, + ) diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py new file mode 100644 index 00000000..b4fa189d --- /dev/null +++ b/fast_llm/layers/vision/vision_encoder.py @@ -0,0 +1,67 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision.config import VisionEncoderConfig + +logger = logging.getLogger(__name__) + + +class VisionEncoder[ConfigType: VisionEncoderConfig](BlockBase[VisionEncoderConfig]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + vision_hidden_dim = TensorDim("hidden", self._config.hidden_size) + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self.patch_convolution = self._config.patch_convolution.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + # TODO: ====== Appropriate name?? 
====== + self.decoder = self._config.decoder.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + # TODO: ====== Hidden dim ====== + self.adapter = self._config.adapter.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def get_layers(self) -> list["Layer"]: + return self.patch_convolution.get_layers() + self.decoder.get_layers() + self.adapter.get_layers() + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + self.patch_convolution.preprocess(batch, kwargs) + self.decoder.preprocess(batch, kwargs) + self.adapter.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + return ( + self.patch_convolution.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.adapter.get_loss_definitions(count) + ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 1e57f3b8..a901a046 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -10,6 +10,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, @@ -84,6 +85,11 @@ def micro_batch_splits(self) -> int: class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig): _abstract = False + # TODO: Allow overriding in sub-models? + peft: PeftConfig = Field( + desc="Configuration for parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) # Debug, to get an exact match with megatron init. 
use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 786d923f..a9249226 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -449,19 +449,13 @@ def get_converters( class LlamaEmbeddingsConverter: @classmethod def import_config(cls, config: dict) -> dict: - return { - "vocab_size": config["vocab_size"], - "hidden_size": config["hidden_size"], - } + return {"vocab_size": config["vocab_size"]} @classmethod def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) assert not config.position_embeddings.enabled - return { - "vocab_size": config.vocab_size, - "hidden_size": config.hidden_size, - } + return {"vocab_size": config.vocab_size} @classmethod def get_converters( @@ -516,6 +510,7 @@ def import_config(cls, config: dict) -> dict: "embeddings": cls.embeddings_converter_class.import_config(config), "decoder": cls.decoder_converter_class.import_config(config), "head": cls.head_converter_class.import_config(config), + "hidden_size": config["hidden_size"], "tied_embedding_weight": config["tie_word_embeddings"], } @@ -526,7 +521,10 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), cls.head_converter_class.export_config(config.head), - {"tie_word_embeddings": config.tied_embedding_weight}, + { + "tie_word_embeddings": config.tied_embedding_weight, + "hidden_size": config.hidden_size, + }, ) @classmethod diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2c1fb0e4..efa348ec 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -30,16 +30,14 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], Ba def __init__( self, - config: GPTBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): - super().__init__(config, distributed_config) + super().__init__(config, distributed_config, lr_scale=config.lr_scale, peft=config.peft) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron( - param, self._config.decoder.block, config.embeddings.hidden_size - ) # Noqa + param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType diff --git a/fast_llm/models/multimodal/__init__.py b/fast_llm/models/multimodal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py new file mode 100644 index 00000000..2415734e --- /dev/null +++ b/fast_llm/models/multimodal/config.py @@ -0,0 +1,89 @@ +import logging +import typing + +from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.vision.config import VisionEncoderConfig +from fast_llm.models.gpt.config import ( 
+ GPTBaseModelConfig, + GPTBatchConfig, + GPTModelConfig, + GPTTrainerConfig, + PretrainedGPTModelConfig, +) + +if typing.TYPE_CHECKING: + from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM + from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalModel, MultiModalModelInferenceRunner + from fast_llm.models.multimodal.trainer import MultiModalTrainer + +logger = logging.getLogger(__name__) + + +@config_class() +class MultiModalBatchConfig(GPTBatchConfig): + pass + + +@config_class() +class MultiModalBaseModelConfig(GPTBaseModelConfig): + vision_encoder: VisionEncoderConfig = Field( + hint=FieldHint.architecture, + desc="Configuration for the vision encoder.", + ) + + @property + def base_model_class(self) -> type["MultiModalBaseModel"]: + from fast_llm.models.multimodal.model import MultiModalBaseModel + + return MultiModalBaseModel + + +@config_class(dynamic_type={FastLLMModelConfig: "gpt"}) +class MultiModalModelConfig(GPTModelConfig): + _abstract = False + model_name: typing.ClassVar[str] = "gpt" + base_model: GPTBaseModelConfig = FieldUpdate() + # TODO: ====== Conversion ====== + checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + + @classmethod + def get_model_class(cls) -> type["MultiModalModel"]: + from fast_llm.models.multimodal.model import MultiModalModel + + return MultiModalModel + + @classmethod + def get_inference_runner_class(cls) -> type["MultiModalModelInferenceRunner"]: + from fast_llm.models.multimodal.model import MultiModalModelInferenceRunner + + return MultiModalModelInferenceRunner + + @classmethod + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceMultiModalModelForCausalLM"]: + from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM + + return HuggingfaceMultiModalModelForCausalLM + + +@config_class() +class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig): + _abstract = False + model: MultiModalModelConfig = FieldUpdate() + + +@config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) +class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig): + data: MultiModalDataConfig = FieldUpdate() + batch: MultiModalBatchConfig = FieldUpdate() + # TODO: Use dynamic model type? + reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate() + + @classmethod + def get_trainer_class(cls) -> type["MultiModalTrainer"]: + from fast_llm.models.multimodal.trainer import MultiModalTrainer + + return MultiModalTrainer diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py new file mode 100644 index 00000000..7426191f --- /dev/null +++ b/fast_llm/models/multimodal/model.py @@ -0,0 +1,133 @@ +import logging +import typing + +import torch + +from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.inference.runner import InferenceRunner +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.gpt.model import GPTBaseModel, GPTModel +from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalBatchConfig, MultiModalModelConfig +from fast_llm.tensor import TensorMeta + +logger = logging.getLogger(__name__) + + +class MultiModalBaseModel[ConfigType: MultiModalBaseModelConfig](GPTBaseModel[ConfigType]): + """ + A transformer-based language model generalizing the GPT model architecture. 
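+    This multimodal variant extends the GPT base model with a vision encoder; per the vision_encoder
+    config and the TODO blocks below, the intent is to embed image patches alongside the text tokens,
+    though most of the vision preprocessing in this class is still commented out.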
+ """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + self.vision_encoder = self._config.vision_encoder.get_layer( + distributed_config, + self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + + def preprocess_meta( + self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + ) -> list[tuple[TensorMeta, dict]]: + # TODO Remove (Move batch splitting elsewhere) + # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # try: + # max_image_size = batch_meta.max_image_size + # except AttributeError: + # max_image_size = 256 + # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + # vision_kwargs = { + # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + # VisionEncoderKwargs.max_image_size: max_image_size, + # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + # } + # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + # vision_hidden_dims = ( + # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + # if sequence_first + # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + # ) + # vision_kwargs.update( + # { + # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + # } + # ) + # common_kwargs.update(vision_kwargs) + + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + # preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + # else: + # preprocessed_meta.append((tokens, kwargs)) + pass + + def preprocess_batch( + self, + batch: GPTBatch, + preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + *, + phase: PhaseType, + iteration: int, + metrics: dict | None = None, + ) -> list[tuple[torch.Tensor, dict]]: + # TODO Move batch splitting elsewhere, align interface with LayerBase + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # if self._config.vision_encoder.image_break_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + # if self._config.vision_encoder.image_end_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + # Loss-masking for distillation losses + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # batch_images = ( + # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[VisionEncoderKwargs.images] = [ + # [ + # img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + # for img in images + # ] + # for images in batch_images + # ] + # kwargs[VisionEncoderKwargs.image_positions] = ( + # batch.image_positions + # if batch.image_positions is not None + # else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[LanguageModelKwargs.tokens] = tokens + # 
image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + # if image_patches is not None: + # preprocessed.append((image_patches, kwargs)) + # else: + # preprocessed.append((tokens, kwargs)) + pass + + +class MultiModalModel[ConfigType: MultiModalModelConfig](GPTModel[ConfigType]): + # TODO: Can we drop class? + pass + + +class MultiModalInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[MultiModalModel]] = MultiModalModel + batch_config_class: typing.ClassVar[type[MultiModalBatchConfig]] = MultiModalBatchConfig diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py new file mode 100644 index 00000000..c4071aaf --- /dev/null +++ b/fast_llm/models/multimodal/trainer.py @@ -0,0 +1,14 @@ +import logging + +from fast_llm.models.gpt.trainer import GPTTrainer +from fast_llm.models.multimodal.config import MultiModalTrainerConfig + +logger = logging.getLogger(__name__) + + +class MultiModalTrainer[ConfigType: MultiModalTrainerConfig](GPTTrainer[ConfigType]): + def _get_data(self) -> MultiModalData: + return MultiModalData( + config=self._config.data, + distributed_config=self._config.model.distributed, + ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index bbd69ae8..1f9feceb 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -316,9 +316,7 @@ def new_decorator(*args, **kwargs): return new_decorator -def compare_nested( - config_a, config_b, errors: list | None = None, prefix: tuple = (), ignore_missing: tuple[str, ...] = () -): +def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): if errors is None: errors = [] # Check for equality of both values and types. diff --git a/setup.cfg b/setup.cfg index 77073ab5..2a161455 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.52.4 + transformers==4.53.2 hf-transfer>=0.1.9 datasets>=3.6.0 huggingface-hub>=0.32.6 @@ -59,6 +59,13 @@ GENERATION = lm_eval>=0.4.9 +# Required for supporting vision inputs +VISION = + # Vision Tools + webp>=0.4.0 + pillow-simd>=9.5.0 + torchvision>=0.20.0 + DEV = # Pre-commit git hook pre-commit>=4.2.0 diff --git a/tests/data/common.py b/tests/data/common.py index d8cc6fff..afee8694 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -122,10 +122,10 @@ def compare_indexed_dataset( loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() + text_sizes, image_sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], text_sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) @@ -218,10 +218,10 @@ def name(self) -> str: def __len__(self) -> int: return self._config.num_documents - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: return self._config.num_tokens_per_document def get(self, index: int, *args, 
**kwargs) -> typing.Any: diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 0de823e2..0dc2421a 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -171,13 +171,8 @@ def test_lm_head( } config = GPTBaseModelConfig.from_dict( { - "decoder": { - "num_blocks": 0, - }, - "embeddings": { - "vocab_size": VOCAB_SIZE, - "hidden_size": HIDDEN_SIZE, - }, + "decoder": {"num_blocks": 0}, + "embeddings": {"vocab_size": VOCAB_SIZE}, "head": ( head_config if prediction_heads == 1 @@ -187,6 +182,7 @@ def test_lm_head( "prediction_heads": prediction_heads, } ), + "hidden_size": HIDDEN_SIZE, }, config_dict, update_type=UpdateType.update, @@ -255,7 +251,7 @@ def test_lm_head( logit_weight = torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ).normal_(config.embeddings.hidden_size**-0.5) + ).normal_(config.hidden_size**-0.5) ) else: logit_weight = None diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 3c3bfb83..5db18d7f 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -343,12 +343,15 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream") - else transformers.AutoModelForCausalLM - ) - model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() + if model_testing_config.name in ("diffusion_llama", "dream"): + auto_model = transformers.AutoModel + elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"): + auto_model = transformers.AutoModelForVision2Seq + else: + auto_model = transformers.AutoModelForCausalLM + model_as_hf = auto_model.from_pretrained( + hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code + ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/test_config.py b/tests/test_config.py index 32620053..63f2606f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,9 +74,6 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "embeddings": { - "hidden_size": 1024, # Default - }, "decoder": { "block": { "mixer": { @@ -92,6 +89,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, # Default }, + "hidden_size": 1024, # Default "tied_embedding_weight": False, }, "multi_stage": {"zero_stage": 3}, @@ -105,7 +103,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "embeddings": {"hidden_size": 512, "vocab_size": 1000}, + "embeddings": {"vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -115,6 +113,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, # Update non-default nested }, }, + "hidden_size": 512, "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type } pretrained_config = PretrainedGPTModelConfig.from_dict( @@ -134,10 +133,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, 
ModelConfigType.model): expected_config["base_model"] = { - "embeddings": { - "hidden_size": 512, - "vocab_size": 1000, - }, + "embeddings": {"vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -152,6 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, + "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6b313aa8..c02521d7 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -192,7 +192,6 @@ def _update_and_add_testing_config( "embeddings": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, - "hidden_size": 256, "num_position_embeddings": 512, "vocab_size": MODEL_TEST_VOCAB_SIZE, }, @@ -216,6 +215,7 @@ def _update_and_add_testing_config( "num_blocks": 2, }, "head": {"output_weight": init_1}, + "hidden_size": 256, "tied_embedding_weight": True, }, "multi_stage": { @@ -344,6 +344,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, ) +del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"] + _update_and_add_testing_config( # Main tested model. diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index bbfa4b21..926bcc34 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -35,6 +35,7 @@ def run(self): dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), "num_documents": len(dataset), + # Todo: fix "num_tokens": dataset.num_tokens, } if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: