Skip to content

Commit

Permalink
Merge pull request #264 from Ankush-Chander/update
Browse files Browse the repository at this point in the history
make lemma_graph undirected
  • Loading branch information
ceteri authored Feb 21, 2024
2 parents 73ee61b + 3f61b59 commit d61cd31
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
12 changes: 7 additions & 5 deletions pytextrank/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def __init__ (
# effectively, performs the same work as the `reset()` method;
# called explicitly here for the sake of type annotations
self.elapsed_time: float = 0.0
self.lemma_graph: nx.DiGraph = nx.DiGraph()
self.lemma_graph: nx.Graph = nx.Graph()
self.phrases: typing.List[Phrase] = []
self.ranks: typing.Dict[Lemma, float] = {}
self.seen_lemma: typing.Dict[Lemma, typing.Set[int]] = OrderedDict()
Expand All @@ -323,7 +323,7 @@ def reset (
removing any pre-existing state.
"""
self.elapsed_time = 0.0
self.lemma_graph = nx.DiGraph()
self.lemma_graph = nx.Graph()
self.phrases = []
self.ranks = {}
self.seen_lemma = OrderedDict()
Expand Down Expand Up @@ -400,15 +400,15 @@ def get_personalization ( # pylint: disable=R0201

def _construct_graph (
self
) -> nx.DiGraph:
) -> nx.Graph:
"""
Construct the
[*lemma graph*](https://derwen.ai/docs/ptr/glossary/#lemma-graph).
returns:
a directed graph representing the lemma graph
"""
g = nx.DiGraph()
g = nx.Graph()

# add nodes made of Lemma(lemma, pos)
g.add_nodes_from(self.node_list)
Expand Down Expand Up @@ -571,6 +571,8 @@ def _calc_discounted_normalised_rank (
returns:
normalized rank metric
"""
if len(span) < 1 :
return 0.0
non_lemma = len([tok for tok in span if tok.pos_ not in self.pos_kept])
non_lemma_discount = len(span) / (len(span) + (2.0 * non_lemma) + 1.0)

Expand Down Expand Up @@ -877,7 +879,7 @@ def write_dot (
path:
path for the output file; defaults to `"graph.dot"`
"""
dot = graphviz.Digraph()
dot = graphviz.Graph()

for lemma in self.lemma_graph.nodes():
rank = self.ranks[lemma]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ def test_stop_words ():
for phrase in doc._.phrases[:5]
]

assert "words" in phrases
assert "sentences" in phrases

# add `"word": ["NOUN"]` to the *stop words*, to remove instances
# of `"word"` or `"words"` then see how the ranked phrases differ?

nlp2 = spacy.load("en_core_web_sm")
nlp2.add_pipe("textrank", config={ "stopwords": { "word": ["NOUN"] } })
nlp2.add_pipe("textrank", config={ "stopwords": { "sentence": ["NOUN"] } })

with open("dat/gen.txt", "r") as f:
doc = nlp2(f.read())
Expand Down

0 comments on commit d61cd31

Please sign in to comment.