
Commit 1cf6bae
Fix dataset download location (#1639)
dakinggg authored and b-chu committed Nov 5, 2024
1 parent b320210 commit 1cf6bae
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
```diff
@@ -34,7 +34,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import importlib
 import logging
 import os
-import tempfile
 import warnings
 from collections.abc import Mapping
 from functools import partial
@@ -93,6 +92,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     UnknownExampleTypeError,
 )
 # yapf: enable
+from llmfoundry.utils.file_utils import dist_mkdtemp
 from llmfoundry.utils.logging_utils import SpecificWarningFilter
 
 log = logging.getLogger(__name__)
@@ -888,6 +888,8 @@ def build_from_hf(
 
     signal_file_path = dist.get_node_signal_file_name()
 
+    download_folder = dist_mkdtemp()
+
     # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
     # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
     # can just read them.
@@ -913,7 +915,7 @@
             if not os.path.isdir(dataset_name):
                 # dataset_name is not a local dir path, download if needed.
                 local_dataset_dir = os.path.join(
-                    tempfile.mkdtemp(),
+                    download_folder,
                     dataset_name,
                 )
 
```
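For context: the fix works because `dist_mkdtemp()` returns the same path on every rank, so the directory that local rank 0 downloads into before the barrier is the one the other ranks read from afterwards, whereas `tempfile.mkdtemp()` returned a different directory per process. The helper itself is not shown in this diff; below is a minimal sketch of what a rank-synchronized temp-directory helper could look like, assuming Composer-style `dist` utilities — the exact mechanics of `llmfoundry.utils.file_utils.dist_mkdtemp` are an assumption here, not taken from this commit.

```python
# Hypothetical sketch only -- not the actual llmfoundry implementation.
import tempfile

from composer.utils import dist


def dist_mkdtemp() -> str:
    """Create one temp dir per node and share its path with all ranks."""
    tempdir = None
    if dist.get_local_rank() == 0:
        # Only local rank 0 actually creates the directory.
        tempdir = tempfile.mkdtemp()
    # Exchange values across ranks; every rank picks the entry produced
    # by its own node's local rank 0.
    node_rank_zero = dist.get_global_rank() - dist.get_local_rank()
    tempdir = dist.all_gather_object(tempdir)[node_rank_zero]
    if tempdir is None:
        raise RuntimeError('Failed to share the temp dir across ranks.')
    return tempdir
```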
5 changes: 3 additions & 2 deletions tests/data/test_dataloader.py
```diff
@@ -455,7 +455,7 @@ def test_finetuning_dataloader_safe_load(
 
     tokenizer = build_tokenizer('gpt2', {})
 
-    with patch('llmfoundry.data.finetuning.tasks.tempfile.mkdtemp', return_value=str(tmp_path)):
+    with patch('llmfoundry.utils.file_utils.tempfile.mkdtemp', return_value=str(tmp_path)), patch('os.cpu_count', return_value=1):
         with expectation:
             _ = build_finetuning_dataloader(
                 tokenizer=tokenizer,
@@ -1516,7 +1516,8 @@ def test_ft_dataloader_with_extra_keys():
         device_batch_size=device_batch_size,
     ).dataloader
 
-@pytest.mark.xfail
+# TODO: Change this back to xfail after figuring out why it caused CI to hang
+@pytest.mark.skip
 def test_text_dataloader_with_extra_keys():
     max_seq_len = 1024
     cfg = {
```
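The patch target in `test_finetuning_dataloader_safe_load` moves because `unittest.mock.patch` must replace a name in the namespace where it is looked up at call time: after this commit, `tempfile.mkdtemp` is called from `llmfoundry.utils.file_utils`, so patching it under `llmfoundry.data.finetuning.tasks` would no longer intercept anything. A generic illustration of that rule, with hypothetical module names:

```python
# pkg/helpers.py (hypothetical module)
import tempfile


def make_scratch_dir() -> str:
    # Resolved through pkg.helpers' namespace at call time.
    return tempfile.mkdtemp()


# test_helpers.py
from unittest.mock import patch

import pkg.helpers


def test_make_scratch_dir(tmp_path):
    # Patch where the name is looked up (pkg.helpers), not where the
    # test imports it from; patching another module's reference to
    # tempfile.mkdtemp would leave this call untouched.
    with patch('pkg.helpers.tempfile.mkdtemp', return_value=str(tmp_path)):
        assert pkg.helpers.make_scratch_dir() == str(tmp_path)
```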
