Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions examples/semdedup_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,31 @@ def main(args):
if semdedup_config.num_files > 0:
input_files = input_files[: semdedup_config.num_files]
logger.info(f"Processing {len(input_files)} files")

ddf = read_data(
input_files=input_files,
file_type=args.input_file_type,
add_filename=False,
backend="cudf",
)
dataset = DocumentDataset(ddf)
semdup = SemDedup(semdedup_config, logger=logger)
dedup_ids = semdup(dataset)
print(dedup_ids.df.head())
logger.info(f"Time taken: {time.time() - st}")
semdup = SemDedup(
semdedup_config,
# Decides whether output of the module is a deduplicated dataset or the IDs of the duplicates
perform_removal=False,
logger=logger,
)
# When perform_removal=False, it will only call .identify_duplicates() and return the list of duplicate IDs.
# When perform_removal=True, then semdup outputs the dataset with the duplicates removed.
# It will behave by calling .identify_duplicates() and .remove() in sequence.
duplicates = semdup(dataset)
print(duplicates.df.head())
logger.info(f"Time taken to identify duplicates: {time.time() - st}")

result = semdup.remove(dataset, duplicates)
print(result.df.head())
logger.info(f"Time taken to remove duplicates: {time.time() - st}")

client.cancel(client.futures, force=True)
client.close()

Expand Down
82 changes: 81 additions & 1 deletion nemo_curator/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
from abc import ABC, abstractmethod
from typing import Literal, Optional
from typing import Literal, Optional, Union

import dask.dataframe as dd

Expand Down Expand Up @@ -82,3 +84,81 @@ def __call__(self, dataset: DocumentDataset):
self._validate_correct_backend(dataset.df)

return self.call(dataset)


class BaseDeduplicationModule(BaseModule):
    """
    Common scaffolding shared by every NeMo Curator deduplication module.

    Subclasses provide ``identify_duplicates`` and ``remove``; ``call`` wires
    the two together depending on ``perform_removal``.
    """

    def __init__(
        self,
        id_field: str,
        text_field: str,
        perform_removal: bool = False,
        logger: Union[logging.LoggerAdapter, str] = "./",
        profile_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        input_backend: Literal["pandas", "cudf", "any"] = "any",
        **kwargs,
    ):
        """
        Args:
            id_field: Name of the column holding the document id.
            text_field: Name of the column holding the document text.
            perform_removal: When True, ``call`` returns the dataset with
                duplicates removed; when False it returns only the duplicate ids.
            logger: An existing logger adapter, or a path to a log directory.
            profile_dir: Optional directory for performance profiles.
            cache_dir: Optional directory for intermediate outputs.
            input_backend: Dataframe backend the module accepts.
            **kwargs: Forwarded unchanged to ``BaseModule``.
        """
        super().__init__(input_backend=input_backend, **kwargs)

        self.id_field = id_field
        self.text_field = text_field
        self.perform_removal = perform_removal
        self.logger = logger
        self.profile_dir = profile_dir
        self.cache_dir = cache_dir

        # Surface configuration combinations that are likely mistakes.
        no_cache = cache_dir is None
        if self.perform_removal and no_cache:
            warnings.warn("cache_dir is recommended to remove duplicates.")
        if no_cache and profile_dir is not None:
            warnings.warn(
                "cache_dir for intermediate outputs is required to generate profiles"
            )
        if not self.perform_removal:
            warnings.warn(
                "In future NeMo Curator releases, the default value for perform_removal will be True."
            )

    @abstractmethod
    def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Find the duplicate documents in a dataset.

        Args:
            dataset (DocumentDataset): The dataset to identify duplicates in
        """
        raise NotImplementedError(
            "identify_duplicates method must be implemented by subclasses"
        )

    @abstractmethod
    def remove(
        self, dataset: DocumentDataset, duplicates_to_remove: DocumentDataset
    ) -> DocumentDataset:
        """
        Drop previously identified duplicates from a dataset.

        Args:
            dataset (DocumentDataset): The dataset to remove duplicates from
        """
        raise NotImplementedError("remove method must be implemented by subclasses")

    def call(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Run the deduplication pipeline end to end.

        Args:
            dataset (DocumentDataset): Input dataset for deduplication.
        Returns:
            DocumentDataset: The duplicate ids when ``perform_removal`` is
            False, otherwise the dataset with duplicates removed.
        """
        identified = self.identify_duplicates(dataset)
        return self.remove(dataset, identified) if self.perform_removal else identified
46 changes: 13 additions & 33 deletions nemo_curator/modules/exact_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@
from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import BaseModule
from nemo_curator.modules.base import BaseDeduplicationModule
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.duplicates_removal import remove_duplicates
from nemo_curator.utils.gpu_utils import is_cudf_type


class ExactDuplicates(BaseModule):
class ExactDuplicates(BaseDeduplicationModule):
"""Find exact duplicates in a document corpus"""

SUPPORTED_HASHES = {"md5"}
Expand All @@ -61,33 +61,21 @@ def __init__(
cache_dir: str, Default None
If specified, will compute & write duplicate id's to cache directory.
"""
super().__init__(input_backend="any")
super().__init__(
id_field=id_field,
text_field=text_field,
input_backend="any",
logger=logger,
perform_removal=perform_removal,
profile_dir=profile_dir,
cache_dir=cache_dir,
)

if hash_method not in self.SUPPORTED_HASHES:
raise ValueError(
f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
)
msg = f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
raise ValueError(msg)

self.hash_method = hash_method
self.id_field = id_field
self.text_field = text_field
self.perform_removal = perform_removal

if not self.perform_removal:
warnings.warn(
"In future NeMo Curator releases, the default value for perform_removal will be True."
)

if self.perform_removal and cache_dir is None:
warnings.warn("cache_dir is recommended to remove duplicates.")

if cache_dir is None and profile_dir is not None:
warnings.warn(
"cache_dir for intermediate outputs is required to generate profiles"
)

self.cache_dir = cache_dir
self.profile_dir = profile_dir

if isinstance(logger, str):
self._logger = create_logger(
Expand Down Expand Up @@ -231,11 +219,3 @@ def remove(
group_field="_hashes",
)
return DocumentDataset(result)

def call(self, dataset: DocumentDataset) -> DocumentDataset:
duplicates = self.identify_duplicates(dataset)

if self.perform_removal:
return self.remove(dataset, duplicates)

return duplicates
30 changes: 15 additions & 15 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import BaseModule
from nemo_curator.modules.base import BaseDeduplicationModule
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets
from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle
Expand All @@ -35,25 +35,36 @@
from nemo_curator.utils.duplicates_removal import remove_duplicates


class FuzzyDuplicates(BaseModule):
class FuzzyDuplicates(BaseDeduplicationModule):
def __init__(
self,
config: FuzzyDuplicatesConfig,
logger: Union[logging.LoggerAdapter, str] = "./",
perform_removal: bool = False,
):
"""
Parameters
----------
config: FuzzyDuplicatesConfig,
Config options for finding FuzzyDuplicates
logger: Existing logger to log to, or a path to a log directory.

perform_removal: Whether to remove duplicates from the dataset.
Default is False.
Returns
-------
DocumentDataset containing IDs of all documents and the corresponding duplicate group
they belong to. Documents in the same group are near duplicates.
"""
super().__init__(input_backend="cudf")
super().__init__(
id_field=config.id_field,
text_field=config.text_field,
input_backend="cudf",
logger=logger,
perform_removal=perform_removal,
profile_dir=config.profile_dir,
cache_dir=config.cache_dir,
)

if isinstance(logger, str):
self._logger = create_logger(
rank=0,
Expand All @@ -64,7 +75,6 @@ def __init__(
self._logger = logger

self.config = config

self.minhash = MinHash(
seed=self.config.seed,
num_hashes=self.config.num_hashes,
Expand Down Expand Up @@ -282,13 +292,3 @@ def remove(
group_field="group",
)
return DocumentDataset(result)

def call(
self, dataset: DocumentDataset, perform_removal: bool = False
) -> DocumentDataset:
duplicates = self.identify_duplicates(dataset)

if perform_removal:
return self.remove(dataset, duplicates)

return duplicates
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,11 @@ def compute_semantic_match_dfs(self) -> None:

def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
"""
Extract deduplicated data based on epsilon value.

Extract similar records that are within the epsilon threshold. These records can be removed from the dataset.
Args:
eps_to_extract (float): Epsilon threshold for extracting deduplicated data.

Returns:
DocumentDataset: Dataset containing deduplicated documents.
DocumentDataset: Dataset containing the list of IDs that can be removed.
"""
if not self.computed_semantic_match_dfs:
raise ValueError(
Expand Down
Loading
Loading