From 1cf6baeeda5ada47da5f4e45503820dba3e64cc5 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 4 Nov 2024 22:38:31 -0800 Subject: [PATCH] Fix dataset download location (#1639) --- llmfoundry/data/finetuning/tasks.py | 6 ++++-- tests/data/test_dataloader.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index b83aee3aa6..dcc1c5491a 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -34,7 +34,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import importlib import logging import os -import tempfile import warnings from collections.abc import Mapping from functools import partial @@ -93,6 +92,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: UnknownExampleTypeError, ) # yapf: enable +from llmfoundry.utils.file_utils import dist_mkdtemp from llmfoundry.utils.logging_utils import SpecificWarningFilter log = logging.getLogger(__name__) @@ -888,6 +888,8 @@ def build_from_hf( signal_file_path = dist.get_node_signal_file_name() + download_folder = dist_mkdtemp() + # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks # can just read them. @@ -913,7 +915,7 @@ def build_from_hf( if not os.path.isdir(dataset_name): # dataset_name is not a local dir path, download if needed. local_dataset_dir = os.path.join( - tempfile.mkdtemp(), + download_folder, dataset_name, ) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 682d25f7f1..3acffa1f5d 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -455,7 +455,7 @@ def test_finetuning_dataloader_safe_load( tokenizer = build_tokenizer('gpt2', {}) - with patch('llmfoundry.data.finetuning.tasks.tempfile.mkdtemp', return_value=str(tmp_path)): + with patch('llmfoundry.utils.file_utils.tempfile.mkdtemp', return_value=str(tmp_path)), patch('os.cpu_count', return_value=1): with expectation: _ = build_finetuning_dataloader( tokenizer=tokenizer, @@ -1516,7 +1516,8 @@ def test_ft_dataloader_with_extra_keys(): device_batch_size=device_batch_size, ).dataloader -@pytest.mark.xfail +# TODO: Change this back to xfail after figuring out why it caused CI to hang +@pytest.mark.skip def test_text_dataloader_with_extra_keys(): max_seq_len = 1024 cfg = {