From 60bf9e514ded5afd4b68ff88ec28809773f653ac Mon Sep 17 00:00:00 2001
From: Maarten Grootendorst
Date: Wed, 2 Dec 2020 11:05:15 +0100
Subject: [PATCH] Add custom CountVectorizer (#14)

---
 keybert/__init__.py |  1 +
 keybert/model.py    | 50 ++++++++++++++++++++++++++++-----------------
 setup.py            |  2 +-
 tests/test_model.py | 33 +++++++++++++++++++-----------
 4 files changed, 54 insertions(+), 32 deletions(-)
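Usage sketch (editor's note, not part of the patch): with this change applied,
keyphrase length is expressed as an n-gram range, and a scikit-learn
CountVectorizer can be passed in to replace the internally constructed one.
The document string below is illustrative only; the API calls follow the
patched signatures:

    from keybert import KeyBERT
    from sklearn.feature_extraction.text import CountVectorizer

    doc = "Supervised learning maps training inputs to labeled outputs."
    model = KeyBERT('distilbert-base-nli-mean-tokens')

    # Default path: an internal CountVectorizer is built from
    # keyphrase_ngram_range and stop_words.
    keywords = model.extract_keywords(doc, keyphrase_ngram_range=(1, 2), top_n=5)

    # Custom path: the supplied vectorizer's own settings are used instead.
    vectorizer = CountVectorizer(ngram_range=(1, 3), stop_words="english")
    keywords = model.extract_keywords(doc, vectorizer=vectorizer)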
diff --git a/keybert/__init__.py b/keybert/__init__.py
index 1d41a2ea..b6e17af9 100644
--- a/keybert/__init__.py
+++ b/keybert/__init__.py
@@ -1 +1,2 @@
 from keybert.model import KeyBERT
+__version__ = "0.1.3"
diff --git a/keybert/model.py b/keybert/model.py
index e890de3e..5a161894 100644
--- a/keybert/model.py
+++ b/keybert/model.py
@@ -3,7 +3,7 @@
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.feature_extraction.text import CountVectorizer
 from tqdm import tqdm
-from typing import List, Union
+from typing import List, Union, Tuple
 import warnings
 from .mmr import mmr
 from .maxsum import max_sum_similarity
@@ -35,14 +35,15 @@ def __init__(self, model: str = 'distilbert-base-nli-mean-tokens'):
     def extract_keywords(self,
                          docs: Union[str, List[str]],
-                         keyphrase_length: int = 1,
+                         keyphrase_ngram_range: Tuple[int, int] = (1, 1),
                          stop_words: Union[str, List[str]] = 'english',
                          top_n: int = 5,
                          min_df: int = 1,
                          use_maxsum: bool = False,
                          use_mmr: bool = False,
                          diversity: float = 0.5,
-                         nr_candidates: int = 20) -> Union[List[str], List[List[str]]]:
+                         nr_candidates: int = 20,
+                         vectorizer: CountVectorizer = None) -> Union[List[str], List[List[str]]]:
         """ Extract keywords/keyphrases

         NOTE:
@@ -62,7 +63,7 @@
         Arguments:
             docs: The document(s) for which to extract keywords/keyphrases
-            keyphrase_length: Length, in words, of the extracted keywords/keyphrases
+            keyphrase_ngram_range: Lower and upper bound, in words, of the extracted keywords/keyphrases
             stop_words: Stopwords to remove from the document
             top_n: Return the top n keywords/keyphrases
             min_df: Minimum document frequency of a word across all documents
@@ -75,6 +76,7 @@
                        is set to True
             nr_candidates: The number of candidates to consider if use_maxsum
                            is set to True
+            vectorizer: Pass in your own CountVectorizer from scikit-learn

         Returns:
             keywords: the top n keywords for a document
@@ -83,43 +85,47 @@
         if isinstance(docs, str):
             return self._extract_keywords_single_doc(docs,
-                                                     keyphrase_length,
+                                                     keyphrase_ngram_range,
                                                      stop_words,
                                                      top_n,
                                                      use_maxsum,
                                                      use_mmr,
                                                      diversity,
-                                                     nr_candidates)
+                                                     nr_candidates,
+                                                     vectorizer)
         elif isinstance(docs, list):
             warnings.warn("Although extracting keywords for multiple documents is faster "
-                          "than iterating over single documents, it requires significant memory "
+                          "than iterating over single documents, it requires significantly more memory "
                           "to hold all word embeddings. Use this at your own discretion!")
             return self._extract_keywords_multiple_docs(docs,
-                                                        keyphrase_length,
+                                                        keyphrase_ngram_range,
                                                         stop_words,
                                                         top_n,
-                                                        min_df=min_df)
+                                                        min_df,
+                                                        vectorizer)

     def _extract_keywords_single_doc(self,
                                      doc: str,
-                                     keyphrase_length: int = 1,
+                                     keyphrase_ngram_range: Tuple[int, int] = (1, 1),
                                      stop_words: Union[str, List[str]] = 'english',
                                      top_n: int = 5,
                                      use_maxsum: bool = False,
                                      use_mmr: bool = False,
                                      diversity: float = 0.5,
-                                     nr_candidates: int = 20) -> List[str]:
+                                     nr_candidates: int = 20,
+                                     vectorizer: CountVectorizer = None) -> List[str]:
         """ Extract keywords/keyphrases for a single document

         Arguments:
             doc: The document for which to extract keywords/keyphrases
-            keyphrase_length: Length, in words, of the extracted keywords/keyphrases
+            keyphrase_ngram_range: Lower and upper bound, in words, of the extracted keywords/keyphrases
             stop_words: Stopwords to remove from the document
             top_n: Return the top n keywords/keyphrases
             use_maxsum: Whether to use Max Sum Similarity
             use_mmr: Whether to use MMR
             diversity: The diversity of results between 0 and 1 if use_mmr is True
             nr_candidates: The number of candidates to consider if use_maxsum is set to True
+            vectorizer: Pass in your own CountVectorizer from scikit-learn

         Returns:
             keywords: The top n keywords for a document
@@ -127,8 +133,10 @@
         """
         try:
             # Extract Words
-            n_gram_range = (keyphrase_length, keyphrase_length)
-            count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words).fit([doc])
+            if vectorizer:
+                count = vectorizer.fit([doc])
+            else:
+                count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words).fit([doc])
             words = count.get_feature_names()

             # Extract Embeddings
@@ -150,28 +158,32 @@

     def _extract_keywords_multiple_docs(self,
                                         docs: List[str],
-                                        keyphrase_length: int = 1,
+                                        keyphrase_ngram_range: Tuple[int, int] = (1, 1),
                                         stop_words: str = 'english',
                                         top_n: int = 5,
-                                        min_df: int = 1):
+                                        min_df: int = 1,
+                                        vectorizer: CountVectorizer = None):
         """ Extract keywords/keyphrases for multiple documents

         This currently does not use MMR or Max Sum Similarity.

         Arguments:
             docs: The documents for which to extract keywords/keyphrases
-            keyphrase_length: Length, in words, of the extracted keywords/keyphrases
+            keyphrase_ngram_range: Lower and upper bound, in words, of the extracted keywords/keyphrases
             stop_words: Stopwords to remove from the document
             top_n: Return the top n keywords/keyphrases
             min_df: Minimum document frequency of a word across all documents
+            vectorizer: Pass in your own CountVectorizer from scikit-learn

         Returns:
             keywords: The top n keywords for a document
         """
         # Extract words
-        n_gram_range = (keyphrase_length, keyphrase_length)
-        count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words, min_df=min_df).fit(docs)
+        if vectorizer:
+            count = vectorizer.fit(docs)
+        else:
+            count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words, min_df=min_df).fit(docs)
         words = count.get_feature_names()
         df = count.transform(docs)
diff --git a/setup.py b/setup.py
index 060ba0c3..7a4b5ff8 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@
 setuptools.setup(
     name="keybert",
     packages=["keybert"],
-    version="0.1.2",
+    version="0.1.3",
     author="Maarten Grootendorst",
     author_email="maartengrootendorst@gmail.com",
     description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",
diff --git a/tests/test_model.py b/tests/test_model.py
index 2590b692..c03fd483 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,47 +1,56 @@
 import pytest
 from .utils import get_test_data
+from sklearn.feature_extraction.text import CountVectorizer

 doc_one, doc_two = get_test_data()


-@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
-def test_single_doc(keyphrase_length, base_keybert):
+@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
+def test_single_doc(keyphrase_length, vectorizer, base_keybert):
     """ Test whether the keywords are correctly extracted """
     top_n = 5
-    keywords = base_keybert.extract_keywords(doc_one, keyphrase_length=keyphrase_length, min_df=1, top_n=top_n)
+
+    keywords = base_keybert.extract_keywords(doc_one,
+                                             keyphrase_ngram_range=keyphrase_length,
+                                             min_df=1,
+                                             top_n=top_n,
+                                             vectorizer=vectorizer)
     assert isinstance(keywords, list)
     assert isinstance(keywords[0], str)
     assert len(keywords) == top_n

     for keyword in keywords:
-        assert len(keyword.split(" ")) == keyphrase_length
+        assert len(keyword.split(" ")) <= keyphrase_length[1]


-@pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [(i+1, truth, not truth)
+@pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [((1, i+1), truth, not truth)
                                                            for i in range(4)
                                                            for truth in [True, False]])
-def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, base_keybert):
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
+def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer, base_keybert):
     """ Test extraction of protected single document method """
     top_n = 5
     keywords = base_keybert._extract_keywords_single_doc(doc_one,
                                                          top_n=top_n,
-                                                         keyphrase_length=keyphrase_length,
+                                                         keyphrase_ngram_range=keyphrase_length,
                                                          use_mmr=mmr,
                                                          use_maxsum=maxsum,
-                                                         diversity=0.5)
+                                                         diversity=0.5,
+                                                         vectorizer=vectorizer)
     assert isinstance(keywords, list)
     assert isinstance(keywords[0], str)
     assert len(keywords) == top_n

     for keyword in keywords:
-        assert len(keyword.split(" ")) == keyphrase_length
+        assert len(keyword.split(" ")) <= keyphrase_length[1]


-@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
+@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
 def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
     """ Test extraction of protected multiple document method """
     top_n = 5
     keywords_list = base_keybert._extract_keywords_multiple_docs([doc_one, doc_two],
                                                                  top_n=top_n,
-                                                                 keyphrase_length=keyphrase_length)
+                                                                 keyphrase_ngram_range=keyphrase_length)
     assert isinstance(keywords_list, list)
     assert isinstance(keywords_list[0], list)
     assert len(keywords_list) == 2
@@ -50,7 +59,7 @@
         assert len(keywords) == top_n

         for keyword in keywords:
-            assert len(keyword.split(" ")) == keyphrase_length
+            assert len(keyword.split(" ")) <= keyphrase_length[1]


 def test_error(base_keybert):
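Editor's note on precedence (illustrative, not part of the patch): in both the
single- and multiple-document paths, a user-supplied vectorizer wins outright;
keyphrase_ngram_range, stop_words, and (for multiple documents) min_df are
silently ignored when one is passed. A hypothetical helper, build_count_model,
sketches the shape of the branch this patch adds:

    from sklearn.feature_extraction.text import CountVectorizer

    def build_count_model(docs, vectorizer=None,
                          keyphrase_ngram_range=(1, 1), stop_words='english'):
        # Hypothetical helper mirroring the patched branch: a user-supplied
        # vectorizer takes full precedence; the keyword-specific arguments
        # only shape the fallback CountVectorizer.
        if vectorizer:
            return vectorizer.fit(docs)
        return CountVectorizer(ngram_range=keyphrase_ngram_range,
                               stop_words=stop_words).fit(docs)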