diff --git a/medcat/vocab.py b/medcat/vocab.py
index b23b24190..04528fd40 100644
--- a/medcat/vocab.py
+++ b/medcat/vocab.py
@@ -190,8 +190,19 @@ def make_unigram_table(self, table_size: int = -1) -> None:
                 "the creation of a massive array. So therefore, there "
                 "is no need to pass the `table_size` parameter anymore.")
         freqs = []
-        for word in self.vec_index2word.values():
+        # index_list maps the slot in which a word's index
+        # sits in vec_index2word to the actual index of that word.
+        # E.g. if we have words indexed 0, 1, and 2,
+        # but only 0 and 2 have corresponding vectors,
+        # then only 0 and 2 will occur in vec_index2word.
+        # While 0 will be in the 0th position (as expected)
+        # in the final probability list, 2 will be in the
+        # 1st position, so we need to record that
+        # slot-to-word-index mapping here.
+        index_list = []
+        for word_index, word in self.vec_index2word.items():
             freqs.append(self[word])
+            index_list.append(word_index)
 
         # Power and normalize frequencies
         freqs = np.array(freqs) ** (3/4)
@@ -199,6 +210,8 @@ def make_unigram_table(self, table_size: int = -1) -> None:
 
         # Calculate cumulative probabilities
         self.cum_probs = np.cumsum(freqs)
+        # the mapping from vector index order to word indices
+        self._index_list = index_list
 
     def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]:
         """Get N negative samples.
@@ -216,8 +229,11 @@ def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]:
         if len(self.cum_probs) == 0:
             self.make_unigram_table()
         random_vals = np.random.rand(n)
-        # NOTE: there's a change in numpy
-        inds = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist())
+        # NOTE: These indices are in terms of the cum_probs array,
+        #       which only has entries for words with vectors.
+        vec_slots = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist())
+        # So we need to translate them back to word indices.
+        inds = list(map(self._index_list.__getitem__, vec_slots))
 
         if ignore_punct_and_num:
             # Do not return anything that does not have letters in it
diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 5e4f8e25e..dd1203ad8 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -43,11 +43,16 @@ class VocabUnigramTableTests(unittest.TestCase):
                                      "..", "examples", "vocab_data.txt")
     UNIGRAM_TABLE_SIZE = 10_000
     # found that this seed had the closest frequency at the sample size we're at
-    RANDOM_SEED = 4976
+    RANDOM_SEED = 32
     NUM_SAMPLES = 20  # NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes
     NUM_TIMES = 200
-    # based on the counts on vocab_data.txt and the one set in setUpClass
-    EXPECTED_FREQUENCIES = [0.62218692, 0.32422858, 0.0535845]
+    # based on the counts in vocab_data.txt and the ones set in setUpClass,
+    # raised to the power of 3/4
+    EXPECTED_FREQUENCIES = {
+        0: 0.61078822, 1: 0.3182886,
+        2: 0.05260281,
+        # NOTE: no 3 since that word has no vector
+        4: 0.01832037}
     TOLERANCE = 0.001
 
     @classmethod
@@ -55,25 +60,50 @@ def setUpClass(cls):
         cls.vocab = Vocab()
         cls.vocab.add_words(cls.EXAMPLE_DATA_PATH)
         cls.vocab.add_word("test", cnt=1310, vec=[1.42, 1.44, 1.55])
+        cls.vocab.add_word("vectorless", cnt=1234, vec=None)
+        cls.vocab.add_word("withvector", cnt=321, vec=[1.3, 1.2, 0.8])
         cls.vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE)
 
     def setUp(self):
         np.random.seed(self.RANDOM_SEED)
 
     @classmethod
-    def _get_freqs(cls) -> list[float]:
+    def _get_freqs(cls) -> dict[int, float]:
         c = Counter()
         for _ in range(cls.NUM_TIMES):
             got = cls.vocab.get_negative_samples(cls.NUM_SAMPLES)
             c += Counter(got)
-        total = sum(c[i] for i in c)
-        got_freqs = [c[i]/total for i in range(len(cls.EXPECTED_FREQUENCIES))]
+        total = sum(c.values())
+        got_freqs = {index: val/total for index, val in c.items()}
         return got_freqs
 
-    def assert_accurate_enough(self, got_freqs: list[float]):
+    @classmethod
+    def _get_abs_max_diff(cls, dict1: dict[int, float],
+                          dict2: dict[int, float]):
+        assert dict1.keys() == dict2.keys()
+        vals1, vals2 = [], []
+        for index in dict1:
+            vals1.append(dict1[index])
+            vals2.append(dict2[index])
+        return np.max(np.abs(np.array(vals1) - np.array(vals2)))
+
+    def assert_accurate_enough(self, got_freqs: dict[int, float]):
+        self.assertEqual(got_freqs.keys(), self.EXPECTED_FREQUENCIES.keys())
         self.assertTrue(
-            np.max(np.abs(np.array(got_freqs) - self.EXPECTED_FREQUENCIES)) < self.TOLERANCE
-        )
+            self._get_abs_max_diff(self.EXPECTED_FREQUENCIES, got_freqs) < self.TOLERANCE)
+
+    def test_does_not_include_vectorless_indices(self, num_samples: int = 100):
+        inds = self.vocab.get_negative_samples(num_samples)
+        for index in inds:
+            with self.subTest(f"Index: {index}"):
+                # the index is in the right list
+                self.assertIn(index, self.vocab.vec_index2word)
+                word = self.vocab.vec_index2word[index]
+                info = self.vocab.vocab[word]
+                # the word's info has a vector
+                self.assertIn("vec", info)
+                # the vector is an array or a list
+                self.assertIsInstance(self.vocab.vec(word), (np.ndarray, list))
 
     def test_negative_sampling(self):
         got_freqs = self._get_freqs()
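To illustrate why the slot-to-word-index translation above is needed, here is a minimal standalone sketch (hypothetical counts and plain numpy, not the MedCAT API): cum_probs is built only from words that have vectors, so the slots returned by np.searchsorted must be mapped back to real word indices.

import numpy as np

# Hypothetical word counts; word 3 has no vector, so it is absent here,
# mirroring how vec_index2word only contains vectored words.
counts = {0: 1310, 1: 321, 2: 20, 4: 5}

index_list = list(counts)  # slot -> word index, e.g. slot 3 -> word 4
freqs = np.array(list(counts.values()), dtype=float) ** (3 / 4)
freqs /= freqs.sum()
cum_probs = np.cumsum(freqs)

random_vals = np.random.rand(10)
# These are positions in cum_probs, not word indices...
vec_slots = np.searchsorted(cum_probs, random_vals).tolist()
# ...so translate them back, as get_negative_samples now does.
inds = [index_list[slot] for slot in vec_slots]
assert 3 not in inds  # the vectorless word can never be sampled

Without the translation, a sampled slot of 3 would be misread as word index 3, i.e. the vectorless word, which is exactly the failure mode the new test_does_not_include_vectorless_indices test guards against.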