diff --git a/README.md b/README.md index 5ca734b..c3705b7 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,15 @@ Execution $ . configure.sh $ make && sudo make install $ lua test/test.lua - => the - the => u.s. -the u.s. => centers -u.s. centers => for -centers for => disease -for disease => control -disease control => and -control and => prevention -and prevention => in -prevention in => atlanta + => the + the => u.s. +the u.s. => centers +u.s. centers => for +centers for => disease +for disease => control +disease control => and +control and => prevention +and prevention => in +prevention in => atlanta ... ``` diff --git a/packages/nnlm.dataset/c_src/nnlm_dataset.cc b/packages/nnlm.dataset/c_src/nnlm_dataset.cc index 3a3d261..ad2fcb1 100644 --- a/packages/nnlm.dataset/c_src/nnlm_dataset.cc +++ b/packages/nnlm.dataset/c_src/nnlm_dataset.cc @@ -79,7 +79,7 @@ namespace LanguageModels { if (length == 1) return tk; // FIXME: refactor this code (*pat)[i] = tk; } - return pat.release(); + return pat.weakRelease(); } Token *NNLMDataSetToken::getPatternBunch(const int *indexes, @@ -111,7 +111,7 @@ namespace LanguageModels { if (length == 1) return tk; // FIXME: refactor this code (*pat)[i] = tk; } - return pat.release(); + return pat.weakRelease(); } void NNLMDataSetToken::putPattern(int index, Token *pat) { diff --git a/packages/nnlm.dataset/lua_src/nnlm_dataset.lua b/packages/nnlm.dataset/lua_src/nnlm_dataset.lua new file mode 100644 index 0000000..4eb6e03 --- /dev/null +++ b/packages/nnlm.dataset/lua_src/nnlm_dataset.lua @@ -0,0 +1,25 @@ +nnlm = nnlm or {} +nnlm.dataset = nnlm.dataset or {} +nnlm.dataset.both = function(t) + local params = get_table_fields({ + corpora = { isa_match=nnlm.corpora, mandatory=true }, + order = { type_match="number", mandatory=true }, + initial_word = { type_match="string", mandatory=true, default="" }, + final_word = { type_match="string", mandatory=true, default="" }, + }, t) + local in_ds = nnlm.dataset{ + offset = -params.order + 1, + length = params.order - 1, + corpora = params.corpora, + initial_word = params.initial_word, + final_word = params.final_word, + } + local out_ds = nnlm.dataset{ + offset = 0, + length = 1, + corpora = params.corpora, + initial_word = params.initial_word, + final_word = params.final_word, + } + return in_ds,out_ds +end diff --git a/test/test.lua b/test/test.lua index f37f2a5..e122897 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1,28 +1,37 @@ +require "aprilann" -- it is not necessary but its clear require "nnlmutils" local dir = arg[0]:dirname() local lex = lexClass.load(io.open(dir.."/voc.398")) local corpora = nnlm.corpora(dir.."/sample.txt",lex.cobj,"") -local in_ds = nnlm.dataset{ +local in_ds,out_ds = nnlm.dataset.both{ corpora = corpora, - length = 2, -- bigrams - offset = -2, -- input requires negative offset - initial_word = "", - final_word = "", -} -local out_ds = nnlm.dataset{ - corpora = corpora, - length = 1, -- unigram - offset = 0, -- output shouldn't requires offset - initial_word = "", - final_word = "", + order = 3, -- trigrams } local function words_of(tk) - return lex:getWordFromWordId(select(2,tk:to_dense():max())) + local idx = select(2,tk:to_dense():max(2)):squeeze() + local t = {} + for i=1,#idx do + t[i] = lex:getWordFromWordId(idx[i]) + end + return t end for i,in_pat in in_ds:patterns() do local out_pat = out_ds:getPattern(i) - print(words_of(in_pat[1]), words_of(in_pat[2]), "=>", words_of(out_pat)) + print(iterator.zip(iterator(words_of(in_pat[1])), + iterator(words_of(in_pat[2])), + iterator.duplicate("=>"), + iterator(words_of(out_pat))):concat(" ", "\n")) end + +print("*******************************************************************") + +local in_pat = in_ds:getPatternBunch(iterator.range(in_ds:numPatterns()):table()) +local out_pat = out_ds:getPatternBunch(iterator.range(in_ds:numPatterns()):table()) + +print(iterator.zip(iterator(words_of(in_pat[1])), + iterator(words_of(in_pat[2])), + iterator.duplicate("=>"), + iterator(words_of(out_pat))):concat(" ", "\n"))