Skip to content

Commit

Permalink
Merge pull request #73 from larsgrobe/master
Browse files Browse the repository at this point in the history
Support gensim4 LdaModel
  • Loading branch information
stijnh authored Feb 5, 2024
2 parents 27919ae + 48f42ea commit ca7fc5d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
1 change: 1 addition & 0 deletions litstudy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
build_corpus,
train_nmf_model,
train_lda_model,
train_elda_model,
compute_word_distribution,
calculate_embedding,
) # noqa: F401
Expand Down
61 changes: 57 additions & 4 deletions litstudy/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,16 +311,69 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel:
:param num_topics: The number of topics to train.
:param seed: The seed used for random number generation.
:param kwargs: Arguments passed to `gensim.models.lda.LdaModel`.
:param kwargs: Arguments passed to `gensim.models.lda.LdaModel` (gensim3)
or `gensim.models.ldamodel.LdaModel` (gensim4).
"""
from gensim.models.lda import LdaModel

dic = corpus.dictionary
freqs = corpus.frequencies

model = LdaModel(list(corpus), **kwargs)
from importlib.metadata import version

doc2topic = corpus2dense(model[freqs], num_topics)
gensim_mayor = int(version("gensim").split(".")[0])

if gensim_mayor == 3:
from gensim.models.lda import LdaModel

model = LdaModel(list(corpus), **kwargs)
elif gensim_mayor == 4:
from gensim.models.ldamodel import LdaModel

model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs)
else:
from sys import exit

exit("LdaModel could not be imported from gensim 3 or 4.")

doc2topic = corpus2dense(model[freqs], num_topics).T
topic2token = model.get_topics()

return TopicModel(dic, doc2topic, topic2token)


def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) -> TopicModel:
    """Train a topic model using ensemble LDA.

    The ensemble is built with gensim's `EnsembleLda` using the
    `"ldamulticore"` topic model class.

    :param corpus: The corpus to train on. Its `dictionary` and `frequencies`
                   attributes are handed to gensim as `id2word` and `corpus`.
    :param num_topics: The number of topics to train.
    :param num_models: The number of models to train.
    :param seed: The seed used for random number generation. It is forwarded
                 as `random_state` unless `kwargs` already provides one.
    :param kwargs: Arguments passed to `gensim.models.ensemblelda.EnsembleLda` (gensim4).
    :returns: A `TopicModel` built from the dictionary, the dense
              document/topic matrix and the topic/token matrix.
    :raises SystemExit: If the installed gensim is older than version 4.
    """
    from importlib.metadata import version

    gensim_major = int(version("gensim").split(".")[0])

    if gensim_major <= 3:
        # Same effect as the previous `sys.exit(...)` call, kept so that
        # existing callers observe the same exception type.
        raise SystemExit("EnsembleLda requires at least gensim 4.")

    from gensim.models.ensemblelda import EnsembleLda

    dic = corpus.dictionary
    freqs = corpus.frequencies

    # Honour `seed`: previously it was accepted but never reached gensim, so
    # results were not reproducible. An explicit `random_state` in kwargs
    # still takes precedence (and avoids a duplicate-keyword TypeError).
    options = dict(kwargs)
    options.setdefault("random_state", seed)

    model = EnsembleLda(
        topic_model_class="ldamulticore",
        corpus=freqs,
        id2word=dic,
        num_topics=num_topics,
        num_models=num_models,
        **options,
    )

    topic2token = model.get_topics()

    # The ensemble determines the number of stable topics itself, so it can
    # differ from `num_topics`. Size the dense matrix from the topic matrix
    # so that both matrices agree on the topic count.
    doc2topic = corpus2dense(model[freqs], len(topic2token)).T

    return TopicModel(dic, doc2topic, topic2token)
Expand Down
1 change: 1 addition & 0 deletions litstudy/sources/scopus_csv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
support loading Scopus CSV export.
"""

from typing import List, Optional
from ..types import Document, Author, DocumentSet, DocumentIdentifier, Affiliation
from ..common import robust_open
Expand Down

0 comments on commit ca7fc5d

Please sign in to comment.