From f4e244ede47b95c5dd880fa49adc5b9d736ee9d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Vial?= Date: Mon, 13 Jan 2020 15:06:37 +0100 Subject: [PATCH] Allow more language models as input (Cleaning - Java part) --- java/src/main/java/NeuralWSDDecode.java | 18 +- java/src/main/java/NeuralWSDDecodeUFSAC.java | 40 +++-- java/src/main/java/NeuralWSDPrepare.java | 3 + .../method/neural/NeuralDataPreparator.java | 2 + .../java/getalp/wsd/utils/ArgumentParser.java | 164 ------------------ 5 files changed, 40 insertions(+), 187 deletions(-) delete mode 100644 java/src/main/java/getalp/wsd/utils/ArgumentParser.java diff --git a/java/src/main/java/NeuralWSDDecode.java b/java/src/main/java/NeuralWSDDecode.java index 682ea91..0ca7b0c 100644 --- a/java/src/main/java/NeuralWSDDecode.java +++ b/java/src/main/java/NeuralWSDDecode.java @@ -21,6 +21,8 @@ public static void main(String[] args) throws Exception new NeuralWSDDecode().decode(args); } + private boolean filterLemma; + private boolean mfsBackoff; private Disambiguator firstSenseDisambiguator; @@ -42,9 +44,10 @@ private void decode(String[] args) throws Exception parser.addArgument("sense_compression_instance_hypernyms", "false"); parser.addArgument("sense_compression_antonyms", "false"); parser.addArgument("sense_compression_file", ""); - parser.addArgument("clear_text", "false"); + parser.addArgument("clear_text", "true"); parser.addArgument("batch_size", "1"); parser.addArgument("truncate_max_length", "150"); + parser.addArgument("filter_lemma", "true"); parser.addArgument("mfs_backoff", "true"); if (!parser.parse(args)) return; @@ -59,8 +62,9 @@ private void decode(String[] args) throws Exception boolean clearText = parser.getArgValueBoolean("clear_text"); int batchSize = parser.getArgValueInteger("batch_size"); int truncateMaxLength = parser.getArgValueInteger("truncate_max_length"); + filterLemma = parser.getArgValueBoolean("filter_lemma"); mfsBackoff = parser.getArgValueBoolean("mfs_backoff"); - + Map senseCompressionClusters = null; if (senseCompressionHypernyms || senseCompressionAntonyms) { @@ -75,6 +79,7 @@ private void decode(String[] args) throws Exception firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30()); neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize); neuralDisambiguator.lowercaseWords = lowercase; + neuralDisambiguator.filterLemma = filterLemma; neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters; reader = new BufferedReader(new InputStreamReader(System.in)); @@ -85,9 +90,12 @@ private void decode(String[] args) throws Exception Sentence sentence = new Sentence(line); if (sentence.getWords().size() > truncateMaxLength) { - sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord); + sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord); + } + if (filterLemma) + { + tagger.tag(sentence.getWords()); } - tagger.tag(sentence.getWords()); sentences.add(sentence); if (sentences.size() >= batchSize) { @@ -113,7 +121,7 @@ private void decodeSentenceBatch(List sentences) throws IOException for (Word word : sentence.getWords()) { writer.write(word.getValue().replace("|", "/")); - if (word.hasAnnotation("lemma") && word.hasAnnotation("pos") && word.hasAnnotation("wsd")) + if (/*word.hasAnnotation("lemma") && word.hasAnnotation("pos") && */ word.hasAnnotation("wsd")) { writer.write("|" + word.getAnnotationValue("wsd")); } diff --git a/java/src/main/java/NeuralWSDDecodeUFSAC.java b/java/src/main/java/NeuralWSDDecodeUFSAC.java index 5775279..67bb41b 100644 --- a/java/src/main/java/NeuralWSDDecodeUFSAC.java +++ b/java/src/main/java/NeuralWSDDecodeUFSAC.java @@ -1,11 +1,13 @@ import getalp.wsd.common.wordnet.WordnetHelper; +import getalp.wsd.method.Disambiguator; +import getalp.wsd.method.FirstSenseDisambiguator; import getalp.wsd.method.neural.NeuralDisambiguator; import getalp.wsd.ufsac.core.Sentence; import getalp.wsd.ufsac.streaming.modifier.StreamingCorpusModifierSentence; import getalp.wsd.ufsac.utils.CorpusPOSTaggerAndLemmatizer; -import getalp.wsd.utils.ArgumentParser; +import getalp.wsd.common.utils.ArgumentParser; import getalp.wsd.utils.WordnetUtils; -import getalp.wsd.common.utils.Wrapper; + import java.util.List; public class NeuralWSDDecodeUFSAC @@ -18,9 +20,11 @@ public static void main(String[] args) throws Exception parser.addArgumentList("weights"); parser.addArgument("input"); parser.addArgument("output"); - parser.addArgument("lowercase", "true"); + parser.addArgument("lowercase", "false"); parser.addArgument("sense_reduction", "true"); - parser.addArgument("lemma_pos_tagged", "false"); + parser.addArgument("clear_text", "true"); + parser.addArgument("batch_size", "1"); + parser.addArgument("mfs_backoff", "true"); if (!parser.parse(args)) return; String pythonPath = parser.getArgValue("python_path"); @@ -30,32 +34,32 @@ public static void main(String[] args) throws Exception String outputPath = parser.getArgValue("output"); boolean lowercase = parser.getArgValueBoolean("lowercase"); boolean senseReduction = parser.getArgValueBoolean("sense_reduction"); - boolean lemmaPOSTagged = parser.getArgValueBoolean("lemma_pos_tagged"); + boolean clearText = parser.getArgValueBoolean("clear_text"); + int batchSize = parser.getArgValueInteger("batch_size"); + boolean mfsBackoff = parser.getArgValueBoolean("mfs_backoff"); - Wrapper lemmaPOSTagger = new Wrapper<>(null); - if (!lemmaPOSTagged) - { - lemmaPOSTagger.obj = new CorpusPOSTaggerAndLemmatizer(); - } - NeuralDisambiguator disambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights); - disambiguator.lowercaseWords = lowercase; - if (senseReduction) disambiguator.reducedOutputVocabulary = WordnetUtils.getReducedSynsetKeysWithHypernyms3(WordnetHelper.wn30()); - else disambiguator.reducedOutputVocabulary = null; + CorpusPOSTaggerAndLemmatizer tagger = new CorpusPOSTaggerAndLemmatizer(); + Disambiguator firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30()); + NeuralDisambiguator neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize); + neuralDisambiguator.lowercaseWords = lowercase; + if (senseReduction) neuralDisambiguator.reducedOutputVocabulary = WordnetUtils.getReducedSynsetKeysWithHypernyms3(WordnetHelper.wn30()); + else neuralDisambiguator.reducedOutputVocabulary = null; StreamingCorpusModifierSentence modifier = new StreamingCorpusModifierSentence() { public void modifySentence(Sentence sentence) { - if (lemmaPOSTagger.obj != null) + tagger.tag(sentence.getWords()); + neuralDisambiguator.disambiguate(sentence, "wsd"); + if (mfsBackoff) { - lemmaPOSTagger.obj.tag(sentence.getWords()); + firstSenseDisambiguator.disambiguate(sentence, "wsd"); } - disambiguator.disambiguate(sentence, "wsd"); } }; modifier.load(inputPath, outputPath); - disambiguator.close(); + neuralDisambiguator.close(); } } diff --git a/java/src/main/java/NeuralWSDPrepare.java b/java/src/main/java/NeuralWSDPrepare.java index 7f6b6b7..0ce5e6b 100644 --- a/java/src/main/java/NeuralWSDPrepare.java +++ b/java/src/main/java/NeuralWSDPrepare.java @@ -27,6 +27,7 @@ public static void main(String[] args) throws Exception parser.addArgument("exclude_line_length", "150"); parser.addArgument("line_length_tokenizer", "null"); parser.addArgument("lowercase", "false"); + parser.addArgument("filter_lemma", "true"); parser.addArgument("uniform_dash", "false"); parser.addArgument("sense_compression_hypernyms", "true"); parser.addArgument("sense_compression_instance_hypernyms", "false"); @@ -53,6 +54,7 @@ public static void main(String[] args) throws Exception int outputFeatureVocabularyLimit = parser.getArgValueInteger("output_feature_vocabulary_limit"); int maxLineLength = parser.getArgValueInteger("truncate_line_length"); boolean lowercase = parser.getArgValueBoolean("lowercase"); + boolean filterLemma = parser.getArgValueBoolean("filter_lemma"); boolean uniformDash = parser.getArgValueBoolean("uniform_dash"); boolean senseCompressionHypernyms = parser.getArgValueBoolean("sense_compression_hypernyms"); boolean senseCompressionInstanceHypernyms = parser.getArgValueBoolean("sense_compression_instance_hypernyms"); @@ -124,6 +126,7 @@ public static void main(String[] args) throws Exception preparator.maxLineLength = maxLineLength; preparator.lowercaseWords = lowercase; + preparator.filterLemma = filterLemma; preparator.uniformDash = uniformDash; preparator.multisenses = false; preparator.removeAllCoarseGrained = true; diff --git a/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java b/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java index 8896a39..2b588c0 100644 --- a/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java +++ b/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java @@ -116,6 +116,8 @@ public class NeuralDataPreparator public boolean lowercaseWords = true; + public boolean filterLemma = true; + public boolean addWordKeyFromSenseKey = false; public boolean uniformDash = false; diff --git a/java/src/main/java/getalp/wsd/utils/ArgumentParser.java b/java/src/main/java/getalp/wsd/utils/ArgumentParser.java deleted file mode 100644 index f0828ec..0000000 --- a/java/src/main/java/getalp/wsd/utils/ArgumentParser.java +++ /dev/null @@ -1,164 +0,0 @@ -package getalp.wsd.utils; - -import java.util.HashMap; -import java.util.Map; -import java.util.List; - -import org.apache.commons.cli.*; - - -public class ArgumentParser -{ - private CommandLineParser parser; - - private Options options; - - private Map defaultValues; - - private Map parsedArgs; - - public ArgumentParser() - { - parser = new DefaultParser(); - options = new Options(); - defaultValues = new HashMap<>(); - parsedArgs = new HashMap<>(); - addOptionalArgument("help"); - } - - public void addOptionalArgument(String name) - { - options.addOption(Option.builder().longOpt(name).build()); - defaultValues.put(name, false); - } - - public void addArgument(String name) - { - options.addOption(Option.builder().longOpt(name).hasArg().required().build()); - defaultValues.put(name, null); - } - - public void addArgument(String name, String defaultValue) - { - options.addOption(Option.builder().longOpt(name).hasArg().build()); - defaultValues.put(name, defaultValue); - } - - public void addArgumentList(String name) - { - options.addOption(Option.builder().longOpt(name).hasArgs().required().build()); - defaultValues.put(name, null); - } - - public void addArgumentList(String name, List defaultValue) - { - options.addOption(Option.builder().longOpt(name).hasArgs().build()); - defaultValues.put(name, defaultValue); - } - - public boolean parse(String[] args) - { - return parse(args, false); - } - - public boolean parse(String[] args, boolean printArgs) - { - try - { - CommandLine cd = parser.parse(options, args); - for (Option opt : cd.getOptions()) - { - if (opt.hasArgs()) - { - parsedArgs.put(opt.getLongOpt(), opt.getValuesList()); - } - else if (opt.hasArg()) - { - parsedArgs.put(opt.getLongOpt(), opt.getValue()); - } - else - { - parsedArgs.put(opt.getLongOpt(), true); - } - } - if (hasArg("help")) - { - printArgs(); - return false; - } - else - { - if (printArgs) - { - printArgs(); - } - return true; - } - } - catch (ParseException e) - { - printArgs(); - System.err.println("Error : " + e.getMessage()); - //new HelpFormatter().printHelp("", options); - return false; - } - } - - public boolean hasArg(String name) - { - return getArgValueGeneric(name); - } - - public String getArgValue(String name) - { - return getArgValueGeneric(name); - } - - public boolean getArgValueBoolean(String name) - { - return getArgValue(name).equals("true"); - } - - public int getArgValueInteger(String name) - { - return Integer.valueOf(getArgValue(name)); - } - - public List getArgValueList(String name) - { - return getArgValueGeneric(name); - } - - @SuppressWarnings("unchecked") - private T getArgValueGeneric(String name) - { - if (parsedArgs.containsKey(name)) - { - return (T) parsedArgs.get(name); - } - else - { - return (T) defaultValues.get(name); - } - } - - public void printArgs() - { - System.out.println("Arguments:"); - for (Option opt : options.getOptions()) - { - if (parsedArgs.containsKey(opt.getLongOpt())) - { - System.out.println(" --" + opt.getLongOpt() + " = " + parsedArgs.get(opt.getLongOpt())); - } - else if (defaultValues.get(opt.getLongOpt()) != null) - { - System.out.println(" --" + opt.getLongOpt() + " (default value) = " + defaultValues.get(opt.getLongOpt())); - } - else - { - System.out.println(" --" + opt.getLongOpt() + " (missing)"); - } - } - } -}