Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

CU-8698f8fgc: Fix negative sampling including indices for words without a vector #524

Merged
merged 4 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions medcat/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,28 @@ def make_unigram_table(self, table_size: int = -1) -> None:
"the creation of a massive array. So therefore, there "
"is no need to pass the `table_size` parameter anymore.")
freqs = []
for word in self.vec_index2word.values():
# index list maps the slot in which a word index
# sits in vec_index2word to the actual index for said word
# e.g:
# if we have words indexed 0, 1, and 2
# but only 0, and 2 have corresponding vectors
# then only 0 and 2 will occur in vec_index2word
# and while 0 will be in the 0th position (as expected)
# in the final probability list, 2 will be in 1st position
# so we need to mark that conversion down
index_list = []
for word_index, word in self.vec_index2word.items():
freqs.append(self[word])
index_list.append(word_index)

# Power and normalize frequencies
freqs = np.array(freqs) ** (3/4)
freqs /= freqs.sum()

# Calculate cumulative probabilities
self.cum_probs = np.cumsum(freqs)
# the mapping from vector index order to word indices
self._index_list = index_list

def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]:
"""Get N negative samples.
Expand All @@ -216,8 +229,11 @@ def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -
if len(self.cum_probs) == 0:
self.make_unigram_table()
random_vals = np.random.rand(n)
# NOTE: there's a change in numpy
inds = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist())
# NOTE: These indices are in terms of the cum_probs array
# which only has word data for words with vectors.
vec_slots = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist())
# so we need to translate these back to word indices
inds = list(map(self._index_list.__getitem__, vec_slots))

if ignore_punct_and_num:
# Do not return anything that does not have letters in it
Expand Down
48 changes: 39 additions & 9 deletions tests/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,67 @@ class VocabUnigramTableTests(unittest.TestCase):
"..", "examples", "vocab_data.txt")
UNIGRAM_TABLE_SIZE = 10_000
# found that this seed had the closest frequency at the sample size we're at
RANDOM_SEED = 4976
RANDOM_SEED = 32
NUM_SAMPLES = 20 # NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes
NUM_TIMES = 200
# based on the counts on vocab_data.txt and the one set in setUpClass
EXPECTED_FREQUENCIES = [0.62218692, 0.32422858, 0.0535845]
# based on the counts on vocab_data.txt and the ones set in setUpClass
# plus the power of 3/4
EXPECTED_FREQUENCIES = {
0: 0.61078822, 1: 0.3182886,
2: 0.05260281,
# NOTE: no 3 since that's got no vectors
4: 0.01832037}
TOLERANCE = 0.001

@classmethod
def setUpClass(cls):
    """Build the shared test vocab once for the whole test class.

    Loads the example word data, then adds three extra words — one of
    them deliberately without a vector — so that word indices and
    vector indices diverge. Finally builds the unigram table used by
    negative sampling.
    """
    cls.vocab = Vocab()
    cls.vocab.add_words(cls.EXAMPLE_DATA_PATH)
    # (word, count, vector) — insertion order matters: word indices are
    # assigned in the order words are added.
    extra_words = [
        ("test", 1310, [1.42, 1.44, 1.55]),
        ("vectorless", 1234, None),
        ("withvector", 321, [1.3, 1.2, 0.8]),
    ]
    for word, count, vector in extra_words:
        cls.vocab.add_word(word, cnt=count, vec=vector)
    cls.vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE)

def setUp(self):
    """Re-seed NumPy's global RNG before each test so sampling is reproducible."""
    seed = self.RANDOM_SEED
    np.random.seed(seed)

@classmethod
def _get_freqs(cls) -> dict[int, float]:
    """Sample negatives repeatedly and return observed per-index frequencies.

    Draws ``NUM_SAMPLES`` negative samples ``NUM_TIMES`` times and
    returns a mapping from word index to the fraction of all samples in
    which that index occurred.

    Returns:
        dict[int, float]: word index -> observed sampling frequency.
    """
    c = Counter()
    for _ in range(cls.NUM_TIMES):
        got = cls.vocab.get_negative_samples(cls.NUM_SAMPLES)
        c += Counter(got)
    total = sum(c.values())
    # Keys are word indices (not dense positions), since indices of
    # words without vectors never occur in the samples.
    got_freqs = {index: val / total for index, val in c.items()}
    return got_freqs

@classmethod
def _get_abs_max_diff(cls, dict1: dict[int, float],
                      dict2: dict[int, float]):
    """Return the largest absolute difference between two dicts' values.

    Both dicts must have exactly the same keys; values are compared
    key-by-key and the maximum absolute difference is returned.
    """
    assert dict1.keys() == dict2.keys()
    vals1, vals2 = [], []
    for index in dict1:
        vals1.append(dict1[index])
        vals2.append(dict2[index])
    return np.max(np.abs(np.array(vals1) - np.array(vals2)))

def assert_accurate_enough(self, got_freqs: dict[int, float]):
    """Assert the sampled frequencies match the expected ones.

    Requires the observed and expected frequency dicts to cover the
    same word indices, and their values to agree within ``TOLERANCE``.
    """
    self.assertEqual(got_freqs.keys(), self.EXPECTED_FREQUENCIES.keys())
    self.assertTrue(
        self._get_abs_max_diff(self.EXPECTED_FREQUENCIES, got_freqs) < self.TOLERANCE)

def test_does_not_include_vectorless_indices(self, num_samples: int = 100):
    """Every negative-sampled index must map to a word that has a vector."""
    sampled = self.vocab.get_negative_samples(num_samples)
    for index in sampled:
        with self.subTest(f"Index: {index}"):
            # the sampled index must be present in the vector-index mapping
            self.assertIn(index, self.vocab.vec_index2word)
            word = self.vocab.vec_index2word[index]
            word_info = self.vocab.vocab[word]
            # the word's metadata must record a vector
            self.assertIn("vec", word_info)
            # and the stored vector must be an array or a list
            self.assertIsInstance(self.vocab.vec(word), (np.ndarray, list),)

def test_negative_sampling(self):
got_freqs = self._get_freqs()
Expand Down