-
Notifications
You must be signed in to change notification settings - Fork 0
/
gensim_utils.py
78 lines (64 loc) · 2.92 KB
/
gensim_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
pyLDAvis Gensim
===============
Helper functions to visualize LDA models trained by Gensim
"""
from __future__ import absolute_import
import funcy as fp
import numpy as np
from scipy.sparse import issparse
#from . import prepare as vis_prepare
import prepare_utils as vis_prepare
def extract_data(topic_model, corpus, dictionary, doc_topic_dists=None):
    """Extract the data pyLDAvis needs from a trained Gensim topic model.

    Parameters
    ----------
    topic_model : gensim LdaModel-like or HdpModel-like
        Trained model.  HDP models are detected by the presence of the
        ``lda_alpha`` / ``lda_beta`` attributes.
    corpus : gensim streaming BoW corpus, or a sparse term-document matrix
        The corpus the model was trained on.
    dictionary : gensim.corpora.Dictionary
        Token id <-> token mapping used by the model.
    doc_topic_dists : list of BoW-style topic dists, sparse matrix,
        or (n_docs, n_topics) array, optional
        Pre-computed document-topic distributions.  If omitted they are
        inferred from the model.

    Returns
    -------
    dict
        Keys ``topic_term_dists``, ``doc_topic_dists``, ``doc_lengths``,
        ``vocab`` and ``term_frequency`` — the inputs expected by the
        prepare() visualization helper.
    """
    import gensim

    if not gensim.matutils.ismatrix(corpus):
        corpus_csc = gensim.matutils.corpus2csc(corpus, num_terms=len(dictionary))
    else:
        corpus_csc = corpus
        # Need corpus to be a streaming gensim list corpus for len and
        # inference functions below:
        corpus = gensim.matutils.Sparse2Corpus(corpus_csc)

    vocab = list(dictionary.token2id.keys())
    # TODO: add the hyperparam to smooth it out? no beta in online LDA impl.. hmm..
    # for now, I'll just make sure we don't ever get zeros...
    beta = 0.01
    fnames_argsort = np.asarray(list(dictionary.token2id.values()), dtype=np.int_)
    term_freqs = corpus_csc.sum(axis=1).A.ravel()[fnames_argsort]
    term_freqs[term_freqs == 0] = beta  # avoid zero frequencies downstream
    doc_lengths = corpus_csc.sum(axis=0).A.ravel()

    assert term_freqs.shape[0] == len(dictionary),\
        'Term frequencies and dictionary have different shape {} != {}'.format(
        term_freqs.shape[0], len(dictionary))
    assert doc_lengths.shape[0] == len(corpus),\
        'Document lengths and corpus have different sizes {} != {}'.format(
        doc_lengths.shape[0], len(corpus))

    if hasattr(topic_model, 'lda_alpha'):
        num_topics = len(topic_model.lda_alpha)
    else:
        num_topics = topic_model.num_topics

    if doc_topic_dists is None:
        # If its an HDP model.
        if hasattr(topic_model, 'lda_beta'):
            gamma = topic_model.inference(corpus)
        else:
            gamma, _ = topic_model.inference(corpus)
        doc_topic_dists = gamma / gamma.sum(axis=1)[:, None]
    else:
        if isinstance(doc_topic_dists, list):
            doc_topic_dists = gensim.matutils.corpus2dense(doc_topic_dists, num_topics).T
        elif issparse(doc_topic_dists):
            doc_topic_dists = doc_topic_dists.T.todense()
        # BUG FIX: normalize against an explicit column vector.  The old
        # `/ doc_topic_dists.sum(axis=1)` only broadcast correctly for
        # np.matrix inputs (the .todense() branch); for a plain ndarray the
        # 1-D row sums mis-broadcast (or raise) unless n_topics == n_docs.
        doc_topic_dists = np.asarray(doc_topic_dists)
        doc_topic_dists = doc_topic_dists / doc_topic_dists.sum(axis=1, keepdims=True)

    assert doc_topic_dists.shape[1] == num_topics,\
        'Document topics and number of topics do not match {} != {}'.format(
        doc_topic_dists.shape[1], num_topics)

    # get the topic-term distribution straight from gensim without
    # iterating over tuples
    if hasattr(topic_model, 'lda_beta'):
        topic = topic_model.lda_beta
    else:
        topic = topic_model.state.get_lambda()
    topic = topic / topic.sum(axis=1)[:, None]  # row-normalize each topic
    # Reorder columns to match the vocab order taken from token2id above.
    topic_term_dists = topic[:, fnames_argsort]

    assert topic_term_dists.shape[0] == doc_topic_dists.shape[1]
    return {'topic_term_dists': topic_term_dists, 'doc_topic_dists': doc_topic_dists,
            'doc_lengths': doc_lengths, 'vocab': vocab, 'term_frequency': term_freqs}