Skip to content

Commit

Permalink
Added nnlm.dataset.both function
Browse files Browse the repository at this point in the history
  • Loading branch information
pakozm committed Jul 31, 2015
1 parent 2ae0080 commit a16e75d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 26 deletions.
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ Execution
$ . configure.sh
$ make && sudo make install
$ lua test/test.lua
<s> <s> => the
<s> the => u.s.
the u.s. => centers
u.s. centers => for
centers for => disease
for disease => control
disease control => and
control and => prevention
and prevention => in
prevention in => atlanta
<s> <s> => the
<s> the => u.s.
the u.s. => centers
u.s. centers => for
centers for => disease
for disease => control
disease control => and
control and => prevention
and prevention => in
prevention in => atlanta
...
```
4 changes: 2 additions & 2 deletions packages/nnlm.dataset/c_src/nnlm_dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ namespace LanguageModels {
if (length == 1) return tk; // FIXME: refactor this code
(*pat)[i] = tk;
}
return pat.release();
return pat.weakRelease();
}

Token *NNLMDataSetToken::getPatternBunch(const int *indexes,
Expand Down Expand Up @@ -111,7 +111,7 @@ namespace LanguageModels {
if (length == 1) return tk; // FIXME: refactor this code
(*pat)[i] = tk;
}
return pat.release();
return pat.weakRelease();
}

void NNLMDataSetToken::putPattern(int index, Token *pat) {
Expand Down
25 changes: 25 additions & 0 deletions packages/nnlm.dataset/lua_src/nnlm_dataset.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
nnlm = nnlm or {}
nnlm.dataset = nnlm.dataset or {}
nnlm.dataset.both = function(t)
local params = get_table_fields({
corpora = { isa_match=nnlm.corpora, mandatory=true },
order = { type_match="number", mandatory=true },
initial_word = { type_match="string", mandatory=true, default="<s>" },
final_word = { type_match="string", mandatory=true, default="</s>" },
}, t)
local in_ds = nnlm.dataset{
offset = -params.order + 1,
length = params.order - 1,
corpora = params.corpora,
initial_word = params.initial_word,
final_word = params.final_word,
}
local out_ds = nnlm.dataset{
offset = 0,
length = 1,
corpora = params.corpora,
initial_word = params.initial_word,
final_word = params.final_word,
}
return in_ds,out_ds
end
37 changes: 23 additions & 14 deletions test/test.lua
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
require "aprilann" -- it is not necessary but its clear
require "nnlmutils"

local dir = arg[0]:dirname()
local lex = lexClass.load(io.open(dir.."/voc.398"))
local corpora = nnlm.corpora(dir.."/sample.txt",lex.cobj,"<unk>")
local in_ds = nnlm.dataset{
local in_ds,out_ds = nnlm.dataset.both{
corpora = corpora,
length = 2, -- bigrams
offset = -2, -- input requires negative offset
initial_word = "<s>",
final_word = "</s>",
}
local out_ds = nnlm.dataset{
corpora = corpora,
length = 1, -- unigram
offset = 0, -- output shouldn't requires offset
initial_word = "<s>",
final_word = "</s>",
order = 3, -- trigrams
}

local function words_of(tk)
return lex:getWordFromWordId(select(2,tk:to_dense():max()))
local idx = select(2,tk:to_dense():max(2)):squeeze()
local t = {}
for i=1,#idx do
t[i] = lex:getWordFromWordId(idx[i])
end
return t
end

for i,in_pat in in_ds:patterns() do
local out_pat = out_ds:getPattern(i)
print(words_of(in_pat[1]), words_of(in_pat[2]), "=>", words_of(out_pat))
print(iterator.zip(iterator(words_of(in_pat[1])),
iterator(words_of(in_pat[2])),
iterator.duplicate("=>"),
iterator(words_of(out_pat))):concat(" ", "\n"))
end

print("*******************************************************************")

local in_pat = in_ds:getPatternBunch(iterator.range(in_ds:numPatterns()):table())
local out_pat = out_ds:getPatternBunch(iterator.range(in_ds:numPatterns()):table())

print(iterator.zip(iterator(words_of(in_pat[1])),
iterator(words_of(in_pat[2])),
iterator.duplicate("=>"),
iterator(words_of(out_pat))):concat(" ", "\n"))

0 comments on commit a16e75d

Please sign in to comment.