Add near-duplicate signal #451

Merged 4 commits on Jul 18, 2023
3 changes: 3 additions & 0 deletions mypy.ini
@@ -109,3 +109,6 @@ follow_imports = skip
ignore_missing_imports = True
follow_imports = skip

[mypy-scipy.integrate.*]
ignore_missing_imports = True
follow_imports = skip
2 changes: 1 addition & 1 deletion src/data/dataset_duckdb.py
@@ -617,7 +617,7 @@ def select_groups(

     leaf_is_float = is_float(leaf.dtype)
     leaf_is_integer = is_integer(leaf.dtype)
-    if leaf_is_float or leaf_is_integer:
+    if not leaf.categorical and (leaf_is_float or leaf_is_integer):
       if named_bins is None:
         # Auto-bin.
         named_bins = _auto_bins(stats, NUM_AUTO_BINS)
11 changes: 11 additions & 0 deletions src/schema.py
@@ -118,6 +118,7 @@ class Field(BaseModel):
  signal: Optional[dict[str, Any]]
  # Maps a named bin to a tuple of (start, end) values.
  bins: Optional[list[Bin]]
  categorical: Optional[bool]

  @validator('fields')
  def either_fields_or_repeated_field_is_defined(
@@ -163,6 +164,13 @@ def validate_bins(cls, bins: list[Bin]) -> list[Bin]:
          f'Bin {i} start ({start}) should be equal to the previous bin end {prev_end}.')
    return bins

  @validator('categorical')
  def validate_categorical(cls, categorical: bool, values: dict[str, Any]) -> bool:
    """Validate the categorical field."""
    if categorical and is_float(values['dtype']):
      raise ValueError('Categorical fields cannot be float dtypes.')
    return categorical

  def __str__(self) -> str:
    return _str_field(self, indent=0)

@@ -256,6 +264,7 @@ def field(
    signal: Optional[dict] = None,
    fields: Optional[object] = None,
    bins: Optional[list[Bin]] = None,
    categorical: Optional[bool] = None,
) -> Field:
  """Parse a field-like object to a Field object."""
  field = _parse_field_like(fields or {}, dtype)
@@ -267,6 +276,8 @@
    field.dtype = dtype
  if bins:
    field.bins = bins
  if categorical is not None:
    field.categorical = categorical
  return field


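As a quick illustration of the new flag (not part of the diff), here is a minimal sketch of declaring a categorical integer field with the `field` helper above; the import path is an assumption for the example.

```python
from src.schema import field  # assumed import path for this sketch

# An integer output that should be grouped by its raw values rather than
# auto-binned, e.g. a cluster id produced by a signal.
cluster_field = field('uint32', categorical=True)
assert cluster_field.categorical is True

# Floats keep being binned: per `validate_categorical` above, a Field with a
# float dtype and categorical=True is rejected when the model is validated.
```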
2 changes: 2 additions & 0 deletions src/signals/default_signals.py
@@ -5,6 +5,7 @@
from ..embeddings.sbert import SBERT
from .concept_labels import ConceptLabelsSignal
from .concept_scorer import ConceptScoreSignal
from .near_dup import NearDuplicateSignal
from .ner import SpacyNER
from .pii import PIISignal
from .signal import register_signal
@@ -21,6 +22,7 @@ def register_default_signals() -> None:
  register_signal(PIISignal)
  register_signal(TextStatisticsSignal)
  register_signal(SpacyNER)
  register_signal(NearDuplicateSignal)

  # Embeddings.
  register_signal(Cohere)
212 changes: 212 additions & 0 deletions src/signals/minhash_dup.py
@@ -0,0 +1,212 @@
"""Find near-duplicates using minhash.

# Code forked from
# https://github.com/bigcode-project/bigcode-dataset/blob/main/near_deduplication/minhash_deduplication.py
# under the Apache 2.0 License.
"""
import gc
import hashlib
import re
import struct
from collections import defaultdict
from itertools import tee
from typing import Iterable, List

import numpy as np
from scipy.integrate import quad as integrate
from tqdm import tqdm

SEED = 42
NON_ALPHA = re.compile('[^A-Za-z_0-9]')
RNG = np.random.RandomState(SEED)
MAX_HASH = np.uint64((1 << 32) - 1)
MERSENNE_PRIME = np.uint64((1 << 61) - 1)


def _ngrams(sequence: List[str], n: int, min_ngram_size: int) -> Iterable:
  """Directly taken from nltk package to avoid dependency.

  Args:
    sequence: The sequence of items to be n-grammed.
    n: The order of the n-grams to be extracted.
    min_ngram_size: The minimum size of n-grams.

  Returns:
    The n-grams generated from the sequence.
  """
  if len(sequence) < min_ngram_size:
    return []
  ngram_size = min(n, len(sequence))
  iterables = tee(sequence, ngram_size)
  for i, sub_iterable in enumerate(iterables):
    for _ in range(i):
      next(sub_iterable, None)
  return zip(*iterables)


def _sha1_hash32(data: bytes) -> int:
  """Directly taken from datasketch package to avoid dependency."""
  return struct.unpack('<I', hashlib.sha1(data).digest()[:4])[0]


def _embed_func(
    content: str,
    num_perm: int,
    ngram_size: int,
    hashranges: list[tuple[int, int]],
    permutations: np.ndarray,
    min_ngram_size: int,
) -> list[bytes]:
  """Combined with some datasketch code to better parallelize computation.

  Args:
    content: The content to be embedded.
    num_perm: The number of permutations.
    ngram_size: The size of n-grams.
    hashranges: The ranges of hash values.
    permutations: The permutations for the minhash.
    min_ngram_size: The minimum size of n-grams.

  Returns:
    The hash values in each range.
  """
  hashvalues = np.ones(num_perm, dtype=np.uint64) * MAX_HASH
  tokens = {' '.join(t) for t in _ngrams(NON_ALPHA.split(content), ngram_size, min_ngram_size)}
  hv = np.array([_sha1_hash32(token.encode('utf-8')) for token in tokens],
                dtype=np.uint64)  # noqa: E501
  a, b = permutations
  phv = np.bitwise_and(((hv * np.tile(a, (len(hv), 1)).T).T + b) % MERSENNE_PRIME,
                       MAX_HASH)  # noqa: E501
  hashvalues = np.vstack([phv, hashvalues]).min(axis=0)
  Hs: list[bytes] = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
  return Hs


def _optimal_param(threshold: float,
                   num_perm: int,
                   false_positive_weight: float = 0.5,
                   false_negative_weight: float = 0.5) -> tuple[int, int]:
  """Find optimal `MinHashLSH` parameter that minimizes the weighted sum of false pos and false neg.

  Taken from datasketch.

  Args:
    threshold: The threshold for similarity.
    num_perm: The number of permutations.
    false_positive_weight: The weight of false positive.
    false_negative_weight: The weight of false negative.

  Returns:
    The optimal `b` and `r` parameters.
    The number of bands, and the number of rows per band respectively.
  """

  def false_positive_probability(threshold: float, b: int, r: int) -> float:
    """Source: `datasketch.lsh`."""

    def proba(s: float) -> float:
      return 1 - (1 - s**float(r))**float(b)

    a, _ = integrate(proba, 0.0, threshold)
    return a

  def false_negative_probability(threshold: float, b: int, r: int) -> float:
    """Source: `datasketch.lsh`."""

    def proba(s: float) -> float:
      return 1 - (1 - (1 - s**float(r))**float(b))

    a, _ = integrate(proba, threshold, 1.0)
    return a

  min_error = float('inf')
  opt = (0, 0)
  for b in range(1, num_perm + 1):
    max_r = int(num_perm / b)
    for r in range(1, max_r + 1):
      fp = false_positive_probability(threshold, b, r)
      fn = false_negative_probability(threshold, b, r)
      error = fp * false_positive_weight + fn * false_negative_weight
      if error < min_error:
        min_error = error
        opt = (b, r)
  return opt


class UnionFind:
  """Union find data structure."""

  def __init__(self) -> None:
    self.parent: dict[int, int] = {}

  def find(self, x: int) -> int:
    """Find the parent of the node."""
    if x not in self.parent:
      self.parent[x] = x
    if self.parent[x] != x:
      self.parent[x] = self.find(self.parent[x])
    return self.parent[x]

  def union(self, x: int, y: int) -> None:
    """Union two nodes."""
    px = self.find(x)
    py = self.find(y)
    self.parent[px] = self.parent[py] = min(px, py)


def find_clusters(data: Iterable[str],
                  ngram_size: int = 5,
                  num_perm: int = 256,
                  threshold: float = 0.7,
                  min_ngram_size: int = 1) -> Iterable[int]:
  """Deduplicates documents and returns cluster ids."""
  uf = UnionFind()
  B, R = _optimal_param(threshold, num_perm)
  HASH_RANGES: list[tuple[int, int]] = [(i * R, (i + 1) * R) for i in range(B)]
  HASH_TABLES: list[dict[bytes, set[int]]] = [defaultdict(set) for _ in range(B)]

  # Consume the data.
  PERMUTATIONS = np.array(
    [(
      RNG.randint(1, MERSENNE_PRIME, dtype=np.uint64),
      RNG.randint(0, MERSENNE_PRIME, dtype=np.uint64),
    ) for _ in range(num_perm)],
    dtype=np.uint64,
  ).T

  # Fingerprinting.
  embedded: list[tuple[int, list[bytes]]] = []
  for key, content in tqdm(enumerate(data), dynamic_ncols=True, desc='Fingerprinting...'):
    hashes = _embed_func(
      content,
      num_perm=num_perm,
      hashranges=HASH_RANGES,
      ngram_size=ngram_size,
      permutations=PERMUTATIONS,
      min_ngram_size=min_ngram_size)
    embedded.append((key, hashes))

  batch_size: int = 10000
  for i in tqdm(
      range(0, len(embedded), batch_size), dynamic_ncols=True, desc='Computing hash collisions...'):
    batch = embedded[i:i + batch_size]
    for (key, Hs) in batch:
      for H, hashtable in zip(Hs, HASH_TABLES):
        hashtable[H].add(key)

  for table in tqdm(HASH_TABLES, dynamic_ncols=True, desc='Clustering...'):
    for cluster in table.values():
      if len(cluster) <= 1:
        continue
      idx = min(cluster)
      for x in cluster:
        uf.union(x, idx)

  gc.freeze()
  gc.disable()
  cluster_ids = [uf.find(i) for i in range(len(embedded))]
  gc.enable()
  gc.collect()

  return cluster_ids
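As a rough usage sketch (not part of the diff), `find_clusters` can be called directly on an iterable of strings. The documents below are made up, and the import path is an assumption; exact duplicates are guaranteed to share a cluster id because their fingerprints collide in every band.

```python
from src.signals.minhash_dup import find_clusters  # assumed import path

docs = [
  'lorem ipsum dolor sit amet',
  'a completely different document',
  'lorem ipsum dolor sit amet',
]
# Exact duplicates hash to identical band signatures, so documents 0 and 2
# end up in the same cluster; the unrelated document keeps its own id.
print(list(find_clusters(docs, threshold=0.7)))  # e.g. [0, 1, 0]
```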
42 changes: 42 additions & 0 deletions src/signals/near_dup.py
@@ -0,0 +1,42 @@
"""Compute near duplicates for a dataset."""
from typing import Iterable, Optional, cast

from pydantic import Field as PydanticField
from typing_extensions import override

from ..schema import Field, Item, RichData, SignalInputType, field
from .minhash_dup import find_clusters
from .signal import TextSignal

CLUSTER_KEY = 'cluster_id'


class NearDuplicateSignal(TextSignal):
  """Find near duplicate documents in a dataset using n-grams.

  <br/>

  Documents are fingerprinted using n-grams with
  [minhash LSH](https://en.wikipedia.org/wiki/MinHash). Documents are assigned the same cluster id
  if their Jaccard similarity is above the provided threshold.
  """
  name = 'near_dup'
  display_name = 'Near duplicate documents'

  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  threshold: float = PydanticField(
    default=0.75,
    description='The similarity threshold for detecting a near duplicate.',
  )

  @override
  def fields(self) -> Field:
    return field(fields={CLUSTER_KEY: field('uint32', categorical=True)})

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    cluster_ids = find_clusters(cast(Iterable[str], data), threshold=self.threshold)
    for cluster_id in cluster_ids:
      yield {CLUSTER_KEY: cluster_id}
18 changes: 18 additions & 0 deletions src/signals/near_dup_test.py
@@ -0,0 +1,18 @@
"""Test the Near duplicate signal."""

from .near_dup import CLUSTER_KEY, NearDuplicateSignal


def test_exact_duplicates() -> None:
  signal = NearDuplicateSignal()
  docs = ['Hello', 'Everyone', 'Hello', 'Hi']
  assert list(signal.compute(docs)) == [{CLUSTER_KEY: x} for x in [0, 1, 0, 3]]


def test_near_dups() -> None:
  signal = NearDuplicateSignal()
  docs = [
    'Hello everyone. This is a test for near duplication with almost the same content',
    'Hello everyone. This is a test for near duplication with almost the same content [time]',
  ]
  assert list(signal.compute(docs)) == [{CLUSTER_KEY: x} for x in [0, 0]]
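The tests above exercise the default threshold of 0.75. As a hedged sketch (not part of the diff, and using an assumed absolute import path), the threshold can also be overridden when constructing the signal, since it is a regular pydantic field:

```python
from src.signals.near_dup import CLUSTER_KEY, NearDuplicateSignal  # assumed path

# A stricter signal: only documents whose estimated Jaccard similarity is at
# least 0.9 are clustered together.
signal = NearDuplicateSignal(threshold=0.9)
docs = ['the same sentence twice', 'the same sentence twice', 'something unrelated']
items = list(signal.compute(docs))
# Exact duplicates always cluster, so the first two items share a cluster id,
# e.g. [{CLUSTER_KEY: 0}, {CLUSTER_KEY: 0}, {CLUSTER_KEY: 2}].
```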
@@ -164,7 +164,7 @@
   <div class="flex w-full flex-col gap-y-6 rounded border border-gray-300 bg-white p-4">
     {#if signalInfo}
       {#key signalInfo}
-        <div class="whitespace-pre-wrap">
+        <div>
           <SvelteMarkdown source={signalInfo.json_schema.description} />
         </div>

1 change: 1 addition & 0 deletions web/lib/fastapi_client/models/Field.ts
@@ -13,5 +13,6 @@ export type Field = {
    dtype?: DataType;
    signal?: Record<string, any>;
    bins?: Array<Array<any>>;
    categorical?: boolean;
};