16 | 16 | from itertools import combinations
17 | 17 | from typing import Iterable
18 | 18 |
| 19 | +import dask.dataframe as dd
19 | 20 | import numpy as np
20 | 21 | import pytest
21 | 22 | import yaml
| 23 | +from dask import config
22 | 24 | from dask.dataframe.utils import assert_eq
23 | 25 | from distributed import Client
24 | 26 |
25 | 27 | from nemo_curator import LSH, FuzzyDuplicates, FuzzyDuplicatesConfig, MinHash
26 | 28 | from nemo_curator.datasets import DocumentDataset
| 29 | +from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import extract_partitioning_index
27 | 30 | from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from
28 | 31 |
29 | 32 | cudf = gpu_only_import("cudf")
@@ -367,3 +370,74 @@ def test_from_yaml(self, tmpdir):
367 | 370 |         config = FuzzyDuplicatesConfig.from_yaml(tmpdir / "config.yaml")
368 | 371 |         for param in yaml_params:
369 | 372 |             assert getattr(config, param) == yaml_params[param]
| 373 | +
| 374 | +
| 375 | +@pytest.mark.parametrize(
| 376 | +    "backend",
| 377 | +    [
| 378 | +        "pandas",
| 379 | +        pytest.param(
| 380 | +            "cudf",
| 381 | +            marks=pytest.mark.gpu,
| 382 | +        ),
| 383 | +    ],
| 384 | +)
| 385 | +def test_extract_partitioning_index(backend):
| 386 | +
| 387 | +    def add_partition_info(df, partition_info=None):
| 388 | +        if partition_info is None:
| 389 | +            df["file_id"] = -1
| 390 | +        else:
| 391 | +            df["file_id"] = partition_info["number"]
| 392 | +        return df
| 393 | +
| 394 | +    with config.set({"dataframe.backend": backend}):
| 395 | +
| 396 | +        # Create a random `unshuffled` DataFrame with a
| 397 | +        # "part_id" column to be used as the shuffle index
| 398 | +        npartitions_left = 7
| 399 | +        unshuffled = dd.from_dict(
| 400 | +            {"part_id": np.random.randint(25, size=1000, dtype="int32")},
| 401 | +            npartitions=npartitions_left,
| 402 | +        )
| 403 | +
| 404 | +        # Create a `bk_mapping` DataFrame that defines
| 405 | +        # the "correct" mapping between "part_id" and
| 406 | +        # the destination partition ("file_id")
| 407 | +        npartitions_right = 5
| 408 | +        bk_mapping = (
| 409 | +            dd.from_dict(
| 410 | +                {"part_id": np.arange(25, dtype="int32")},
| 411 | +                npartitions=npartitions_right,
| 412 | +            )
| 413 | +            .shuffle("part_id")
| 414 | +            .map_partitions(add_partition_info)
| 415 | +            .compute()
| 416 | +        )
| 417 | +
| 418 | +        # Use `extract_partitioning_index` to calculate
| 419 | +        # the partitioning index and assign it as a new
| 420 | +        # "_partitions" column
| 421 | +        result, _ = extract_partitioning_index(
| 422 | +            unshuffled,
| 423 | +            "part_id",
| 424 | +            bk_mapping,
| 425 | +            npartitions_right,
| 426 | +            npartitions_right,
| 427 | +        )
| 428 | +
| 429 | +        # Rename the "_partitions" column, shuffle by "part_id",
| 430 | +        # and then assign a "file_id" column to reflect the final
| 431 | +        # partition of each row
| 432 | +        check = (
| 433 | +            result.rename(columns={"_partitions": "expected_file_id"})
| 434 | +            .shuffle(
| 435 | +                "part_id",
| 436 | +                npartitions=npartitions_right,
| 437 | +            )
| 438 | +            .map_partitions(add_partition_info)
| 439 | +            .compute()
| 440 | +        )
| 441 | +
| 442 | +        # Check that the real and expected partitions match
| 443 | +        assert (check["file_id"] == check["expected_file_id"]).all()