-
Notifications
You must be signed in to change notification settings - Fork 11
Add local pHash deduplication utility #462
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,372 @@ | ||
| """Local pHash-based deduplication utilities.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import itertools | ||
| import re | ||
| from collections.abc import Callable, Iterable, Mapping | ||
| from dataclasses import dataclass | ||
| from typing import Generic, List, Optional, Sequence, TypeVar, Union | ||
|
|
||
| from .constants import ITEM_KEY | ||
| from .dataset_item import DatasetItem | ||
| from .deduplication import DeduplicationStats | ||
|
|
||
| InputT = TypeVar("InputT") | ||
| _TieBreakKey = tuple[int, Union[int, str]] | ||
|
|
||
| PHASH_REGEX = re.compile(r"^[01]{64}$") | ||
| DEDUP_THRESHOLD_MIN = 0 | ||
| DEDUP_THRESHOLD_MAX = 64 | ||
| LOCAL_INDEX_MAX_THRESHOLD = 10 | ||
| PARTITION_COUNT = 2 | ||
| INDEX_CHUNK_BITS = (11, 11, 11, 11, 10, 10) | ||
| INDEX_CHUNK_COUNT = len(INDEX_CHUNK_BITS) | ||
| PHASH_VALUE_MASK = (1 << 64) - 1 | ||
| ROTATED_PARTITION_BITS = 8 | ||
|
|
||
|
|
||
| @dataclass | ||
| class LocalDeduplicationResult(Generic[InputT]): | ||
| """Output of a local pHash deduplication run. | ||
|
|
||
| Attributes: | ||
| unique: Input objects that survived deduplication. If you passed rows | ||
| from ``items_and_annotation_generator()``, this contains those same | ||
| row dictionaries. If you passed ``DatasetItem`` objects, it contains | ||
| ``DatasetItem`` objects. | ||
| unique_dataset_items: The DatasetItem for each object in ``unique``. | ||
| unique_reference_ids: Reference IDs for the DatasetItems in ``unique``. | ||
| Entries can be ``None`` if a DatasetItem has no reference ID. | ||
| stats: Summary statistics for the run. | ||
| """ | ||
|
|
||
| unique: List[InputT] | ||
| unique_dataset_items: List[DatasetItem] | ||
| unique_reference_ids: List[Optional[str]] | ||
| stats: DeduplicationStats | ||
|
|
||
|
|
||
| @dataclass(slots=True) | ||
| class _DeduplicationRecord(Generic[InputT]): | ||
| stable_id: _TieBreakKey | ||
| phash_value: int | ||
| obj: InputT | ||
| dataset_item: DatasetItem | ||
|
|
||
|
|
||
| class _HammingIndex: | ||
| """Exact Hamming near-neighbor index for 64-bit hashes. | ||
|
|
||
| The first partition splits the pHash into six contiguous chunks. | ||
| The second partition rotates the pHash by eight bits, then applies the same | ||
| chunking. If two hashes are within threshold ``t``, then each partition has | ||
| at least one chunk within ``floor(t / 6)``. Querying both partitions and | ||
| intersecting the candidates keeps the result exact while avoiding a full | ||
| scan for the intended ``t <= 10`` case. | ||
| """ | ||
|
|
||
| def __init__(self, threshold: int): | ||
| self._threshold = threshold | ||
| self._chunk_radius = threshold // INDEX_CHUNK_COUNT | ||
| self._variant_masks_by_chunk = [ | ||
| _variant_masks(self._chunk_radius, chunk_bits) | ||
| for chunk_bits in INDEX_CHUNK_BITS | ||
| ] | ||
| self._hashes: List[int] = [] | ||
| self._candidate_marks = bytearray() | ||
| self._indexes: List[List[dict[int, List[int]]]] = [ | ||
| [{} for _ in range(INDEX_CHUNK_COUNT)] | ||
| for _ in range(PARTITION_COUNT) | ||
| ] | ||
|
|
||
| def _add(self, phash_value: int) -> None: | ||
| kept_index = len(self._hashes) | ||
| self._hashes.append(phash_value) | ||
| self._candidate_marks.append(0) | ||
| for partition_index in range(PARTITION_COUNT): | ||
| chunks = _partition_chunks(phash_value, partition_index) | ||
| for chunk_index, chunk_value in enumerate(chunks): | ||
| index = self._indexes[partition_index][chunk_index] | ||
| index.setdefault(chunk_value, []).append(kept_index) | ||
|
|
||
| def add_if_unique(self, phash_value: int) -> bool: | ||
| if self._find_duplicate_index(phash_value) is not None: | ||
| return False | ||
| self._add(phash_value) | ||
| return True | ||
|
|
||
| def _find_duplicate_index(self, phash_value: int) -> Optional[int]: | ||
| if not self._hashes: | ||
| return None | ||
|
|
||
| touched_indexes = self._mark_partition_zero_candidates(phash_value) | ||
| if not touched_indexes: | ||
| return None | ||
|
|
||
| duplicate_index = self._find_marked_partition_one_duplicate( | ||
| phash_value | ||
| ) | ||
| _clear_candidate_marks(self._candidate_marks, touched_indexes) | ||
| return duplicate_index | ||
|
|
||
| def _mark_partition_zero_candidates(self, phash_value: int) -> List[int]: | ||
| touched_indexes: List[int] = [] | ||
| marks = self._candidate_marks | ||
| partition_zero_chunks = _partition_chunks(phash_value, 0) | ||
| for chunk_index, chunk_value in enumerate(partition_zero_chunks): | ||
| index = self._indexes[0][chunk_index] | ||
| for mask in self._variant_masks_by_chunk[chunk_index]: | ||
| bucket = index.get(chunk_value ^ mask) | ||
| if bucket is None: | ||
| continue | ||
| for kept_index in bucket: | ||
| if marks[kept_index] == 0: | ||
| marks[kept_index] = 1 | ||
| touched_indexes.append(kept_index) | ||
| return touched_indexes | ||
|
|
||
| def _find_marked_partition_one_duplicate( | ||
| self, phash_value: int | ||
| ) -> Optional[int]: | ||
| marks = self._candidate_marks | ||
| partition_one_chunks = _partition_chunks(phash_value, 1) | ||
| for chunk_index, chunk_value in enumerate(partition_one_chunks): | ||
| index = self._indexes[1][chunk_index] | ||
| for mask in self._variant_masks_by_chunk[chunk_index]: | ||
| bucket = index.get(chunk_value ^ mask) | ||
| if bucket is None: | ||
| continue | ||
| for kept_index in bucket: | ||
| if marks[kept_index] != 1: | ||
| continue | ||
| if ( | ||
| phash_value ^ self._hashes[kept_index] | ||
| ).bit_count() <= self._threshold: | ||
| return kept_index | ||
| marks[kept_index] = 2 | ||
| return None | ||
|
|
||
|
|
||
| def deduplicate_by_phash( | ||
| objects: Iterable[InputT], | ||
| threshold: int, | ||
| *, | ||
| sort_key: Optional[Callable[[InputT], Union[str, int]]] = None, | ||
| ) -> LocalDeduplicationResult[InputT]: | ||
| """Deduplicate local DatasetItems or item/annotation rows by pHash. | ||
|
|
||
| This utility does not call the Nucleus API and does not require an API key. | ||
| It accepts either ``DatasetItem`` objects or rows returned from | ||
| ``items_and_annotation_generator()`` / ``items_and_annotations()`` where | ||
| the row has an ``"item"`` key containing a ``DatasetItem``. The returned | ||
| ``unique`` list contains the same kind of objects that were passed in. | ||
|
|
||
| Args: | ||
| objects: DatasetItems or item/annotation rows to deduplicate. | ||
| threshold: Hamming distance threshold (0-64). Lower = stricter. | ||
| ``0`` deduplicates exact pHash matches only. | ||
| sort_key: Optional stable tie-break key used after pHash sorting. When | ||
| omitted, DatasetItem.reference_id is used. | ||
|
|
||
| Returns: | ||
| LocalDeduplicationResult containing the surviving input objects and | ||
| summary statistics. | ||
|
|
||
| Raises: | ||
| TypeError: If an input object is neither a DatasetItem nor a mapping | ||
| with an ``"item"`` DatasetItem. | ||
| ValueError: If threshold is invalid, or any selected DatasetItem is | ||
| missing a 64-character binary pHash. | ||
| """ | ||
|
|
||
| _validate_threshold(threshold) | ||
| records = _normalize_records(objects, sort_key=sort_key) | ||
| if not records: | ||
| return LocalDeduplicationResult( | ||
| unique=[], | ||
| unique_dataset_items=[], | ||
| unique_reference_ids=[], | ||
| stats=DeduplicationStats( | ||
| threshold=threshold, | ||
| original_count=0, | ||
| deduplicated_count=0, | ||
| ), | ||
| ) | ||
|
|
||
| records.sort(key=lambda record: (record.phash_value, record.stable_id)) | ||
|
|
||
| if threshold == 0: | ||
| unique_records = _deduplicate_exact(records) | ||
| elif threshold == DEDUP_THRESHOLD_MAX: | ||
| unique_records = [records[0]] | ||
| elif threshold <= LOCAL_INDEX_MAX_THRESHOLD: | ||
| unique_records = _deduplicate_with_index(records, threshold) | ||
| else: | ||
| unique_records = _deduplicate_with_linear_scan(records, threshold) | ||
|
|
||
| return LocalDeduplicationResult( | ||
| unique=[record.obj for record in unique_records], | ||
| unique_dataset_items=[ | ||
| record.dataset_item for record in unique_records | ||
| ], | ||
| unique_reference_ids=[ | ||
| record.dataset_item.reference_id for record in unique_records | ||
| ], | ||
|
greptile-apps[bot] marked this conversation as resolved.
|
||
| stats=DeduplicationStats( | ||
| threshold=threshold, | ||
| original_count=len(records), | ||
| deduplicated_count=len(unique_records), | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| def _validate_threshold(threshold: int) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this validation not also happening in the original dedup sdk path? either way, can it be shared between both?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I left this local-only for now. The existing async |
||
| if ( | ||
| not isinstance(threshold, int) | ||
| or isinstance(threshold, bool) | ||
| or threshold < DEDUP_THRESHOLD_MIN | ||
| or threshold > DEDUP_THRESHOLD_MAX | ||
| ): | ||
| raise ValueError( | ||
| "threshold must be an integer between 0 and 64 (inclusive)" | ||
| ) | ||
|
|
||
|
|
||
| def _normalize_records( | ||
| objects: Iterable[InputT], | ||
| *, | ||
| sort_key: Optional[Callable[[InputT], Union[str, int]]], | ||
| ) -> List[_DeduplicationRecord[InputT]]: | ||
| records = [] | ||
| for ordinal, obj in enumerate(objects): | ||
| dataset_item = _extract_dataset_item(obj) | ||
| phash = dataset_item.phash | ||
| if phash is None or not PHASH_REGEX.fullmatch(phash): | ||
| ref_id = dataset_item.reference_id | ||
| ref_msg = f" with reference_id {ref_id!r}" if ref_id else "" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: would it be useful to include the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did not add this because the supported local input shape only exposes a |
||
| raise ValueError( | ||
| "All DatasetItems must have a 64-character binary pHash. " | ||
| f"Found missing or malformed pHash on item{ref_msg}." | ||
| ) | ||
|
|
||
| tie_breaker = ( | ||
| sort_key(obj) | ||
| if sort_key is not None | ||
| else dataset_item.reference_id | ||
| ) | ||
| stable_id = _tie_break_key(tie_breaker, ordinal) | ||
| records.append( | ||
| _DeduplicationRecord( | ||
| stable_id=stable_id, | ||
| phash_value=int(phash, 2), | ||
| obj=obj, | ||
| dataset_item=dataset_item, | ||
| ) | ||
| ) | ||
| return records | ||
|
|
||
|
|
||
| def _tie_break_key( | ||
| tie_breaker: Optional[Union[str, int]], ordinal: int | ||
| ) -> _TieBreakKey: | ||
| if tie_breaker is None: | ||
| return (2, ordinal) | ||
| if isinstance(tie_breaker, int): | ||
| return (0, tie_breaker) | ||
| return (1, tie_breaker) | ||
|
|
||
|
|
||
| def _extract_dataset_item(obj: InputT) -> DatasetItem: | ||
| if isinstance(obj, DatasetItem): | ||
| return obj | ||
|
|
||
| if isinstance(obj, Mapping): | ||
| item = obj.get(ITEM_KEY) | ||
| if isinstance(item, DatasetItem): | ||
| return item | ||
|
|
||
| raise TypeError( | ||
| "Expected each object to be a DatasetItem or a mapping with an " | ||
| f"{ITEM_KEY!r} DatasetItem." | ||
| ) | ||
|
|
||
|
|
||
| def _deduplicate_exact( | ||
| records: Sequence[_DeduplicationRecord[InputT]], | ||
| ) -> List[_DeduplicationRecord[InputT]]: | ||
| seen_phashes = set() | ||
| unique_records = [] | ||
| for record in records: | ||
| if record.phash_value in seen_phashes: | ||
| continue | ||
| seen_phashes.add(record.phash_value) | ||
| unique_records.append(record) | ||
| return unique_records | ||
|
|
||
|
|
||
| def _deduplicate_with_index( | ||
| records: Sequence[_DeduplicationRecord[InputT]], threshold: int | ||
| ) -> List[_DeduplicationRecord[InputT]]: | ||
| index = _HammingIndex(threshold) | ||
| unique_records = [] | ||
| for record in records: | ||
| if index.add_if_unique(record.phash_value): | ||
| unique_records.append(record) | ||
| return unique_records | ||
|
|
||
|
|
||
| def _deduplicate_with_linear_scan( | ||
| records: Sequence[_DeduplicationRecord[InputT]], threshold: int | ||
| ) -> List[_DeduplicationRecord[InputT]]: | ||
| kept_hashes: List[int] = [] | ||
| unique_records = [] | ||
| for record in records: | ||
| is_duplicate = any( | ||
| (record.phash_value ^ kept_hash).bit_count() <= threshold | ||
| for kept_hash in kept_hashes | ||
| ) | ||
| if is_duplicate: | ||
| continue | ||
| kept_hashes.append(record.phash_value) | ||
| unique_records.append(record) | ||
| return unique_records | ||
|
|
||
|
|
||
| def _partition_chunks( | ||
| phash_value: int, partition_index: int | ||
| ) -> tuple[int, ...]: | ||
| if partition_index == 1: | ||
| phash_value = _rotate_right_64(phash_value, ROTATED_PARTITION_BITS) | ||
|
|
||
| chunks = [] | ||
| shift = 64 | ||
| for chunk_bits in INDEX_CHUNK_BITS: | ||
| shift -= chunk_bits | ||
| chunks.append((phash_value >> shift) & ((1 << chunk_bits) - 1)) | ||
| return tuple(chunks) | ||
|
|
||
|
|
||
| def _rotate_right_64(phash_value: int, bits: int) -> int: | ||
| return ( | ||
| (phash_value >> bits) | (phash_value << (64 - bits)) | ||
| ) & PHASH_VALUE_MASK | ||
|
|
||
|
|
||
| def _clear_candidate_marks( | ||
| marks: bytearray, touched_indexes: List[int] | ||
| ) -> None: | ||
| for kept_index in touched_indexes: | ||
| marks[kept_index] = 0 | ||
|
|
||
|
|
||
| def _variant_masks(radius: int, chunk_bits: int) -> List[int]: | ||
| masks = [0] | ||
| bit_indexes = range(chunk_bits) | ||
| for distance in range(1, radius + 1): | ||
| for positions in itertools.combinations(bit_indexes, distance): | ||
| mask = 0 | ||
| for position in positions: | ||
| mask |= 1 << position | ||
| masks.append(mask) | ||
| return masks | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nbd but might be slightly more efficient to place this above the sort on 179? if i understand correctly the sort won't drop any records
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, the empty-input return now happens immediately after normalization and before the sort.