diff --git a/README.md b/README.md
index 158e5b7..f08d9fb 100644
--- a/README.md
+++ b/README.md
@@ -1,80 +1,116 @@
 # disambiguate: Neural Word Sense Disambiguation Toolkit
-This repository contains a set of easy-to-use tools for training, evaluating and using neural WSD models.
-This is the implementation used in the article [Improving the Coverage and the Generalization Ability of Neural Word Sense Disambiguation through Hypernymy and Hyponymy Relationships](https://arxiv.org/abs/1811.00960), written by Loïc Vial, Benjamin Lecouteux and Didier Schwab.
+
+This repository contains a set of easy-to-use tools for training, evaluating and using neural WSD models.
+
+This is the implementation used in the article [Sense Vocabulary Compression through the Semantic Knowledge of WordNet for Neural Word Sense Disambiguation](https://arxiv.org/abs/1905.05677), written by Loïc Vial, Benjamin Lecouteux and Didier Schwab.
 
 ## Dependencies
 - Python (version 3.6 or higher)
 
 - Java (version 8 or higher)
 
 - Maven
 
-- PyTorch (version 0.4.0 or higher)
-
+- PyTorch (version 1.0.0 or higher)
+
+- (optional, for using ELMo) AllenNLP
+
+- (optional, for using BERT) huggingface's pytorch-pretrained-BERT
+
 - UFSAC
 
 To install **Python**, **Java** and **Maven**, you can use the package manager of your distribution (apt-get, pacman...).
 
-To install **PyTorch**, please follow [this page](https://pytorch.org/get-started).
+To install **PyTorch**, please follow the instructions on [this page](https://pytorch.org/get-started).
+
+To install **AllenNLP** (necessary if using ELMo), please follow the instructions on [this page](https://allennlp.org/tutorials).
+
+To install **huggingface's pytorch-pretrained-BERT** (necessary if using BERT), please follow the instructions on [this page](https://github.com/huggingface/pytorch-pretrained-BERT).
 
 To install **UFSAC**, simply:
+
 - download the content of the [UFSAC repository](https://github.com/getalp/UFSAC)
-- go into the `java` folder
+- go into the `java` folder
 - run `mvn install`
 
 ## Compilation
 
-Once the dependencies are installed, please run `./java/compile.sh` to compile the Java code.
+Once the dependencies are installed, please run `./java/compile.sh` to compile the Java code.
 
-## Use pre-trained models
+## Using pre-trained models
 
-At the moment we are only providing one of our best model trained on the SemCor and the WordNet Gloss Tagged, with the vocabulary reduction applied, as described in [our article](https://arxiv.org/abs/1811.00960).
+We are currently providing one of our best models, trained on the SemCor and the WordNet Gloss Tagged, using BERT embeddings, with the vocabulary compression through the hypernymy/hyponymy relationships applied, as described in [our article](https://arxiv.org/abs/1905.05677).
-Here is the link to the data: 
+Here is the link to the data: 
 
 Once the data are downloaded and extracted, you can use the following commands (replace `$DATADIR` with the path of the appropriate folder; an example session is sketched at the end of this section):
 
-- `./decode.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd`
+
+### Disambiguating raw text
+
+- `./decode.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd0`
 
   This script allows you to disambiguate raw text from the standard input to the standard output.
 
-- `./evaluate.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd --corpus [UFSAC corpus]...`
+### Evaluating a model
 
-  This script evaluates a WSD model by computing its coverage, precision, recall and F1 scores on sense-annotated corpora in the UFSAC format, with and without first sense backoff.
+- `./evaluate.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd0 --corpus [UFSAC corpus]...`
+
+  This script evaluates a WSD model by computing its coverage, precision, recall and F1 scores on sense-annotated corpora in the UFSAC format, with and without first sense backoff.
 
 Description of the arguments:
-
+
+- `--data_path [DIR]` is the path to the directory containing the files needed for describing the model architecture (files `config.json`, `input_vocabularyX` and `output_vocabularyX`)
 - `--weights [FILE]...` is a list of model weights: if multiple weights are given, an ensemble of these weights is used in `decode.sh`, and both the evaluation of the ensemble of weights and the evaluation of each individual weight is performed in `evaluate.sh`
 - `--corpus [FILE]...` (`evaluate.sh` only) is the list of UFSAC corpora used for evaluating the WSD model
 
-Optional arguments:
-- `--lowercase [true|false]` (default `true`) if you want to enable/disable lowercasing of input
-- `--sense_reduction [true|false]` (default `true`) if you want to enable/disable the sense vocabulary reduction method.
+Optional arguments:
+
+- `--lowercase [true|false]` (default `false`) if you want to enable/disable lowercasing of input
+- `--batch_size [N]` (default `1`) is the batch size.
+- `--sense_compression_hypernyms [true|false]` (default `true`) if you want to enable/disable the sense vocabulary compression through the hypernym/hyponym relationships.
+- `--sense_compression_file [FILE]` if you want to use another sense vocabulary compression mapping.
 
-UFSAC corpora are available in the [UFSAC repository](https://github.com/getalp/UFSAC). If you want to reproduce our results, please download UFSAC 2.1 and you will find the SemCor (file `semcor.xml`), the WordNet Gloss Tagged (file `wngt.xml`) and all the SemEval/SensEval evaluation corpora that we used.
+UFSAC corpora are available in the [UFSAC repository](https://github.com/getalp/UFSAC). If you want to reproduce our results, please download UFSAC 2.1 and you will find the SemCor (file `semcor.xml`), the WordNet Gloss Tagged (file `wngt.xml`) and all the SemEval/SensEval evaluation corpora that we used (files `raganato_*.xml`).
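+
+### Example
+
+As an illustration of the two commands above, here is a minimal sketch of a session (hypothetical paths: `/path/to/extracted/data` stands for the folder extracted from the pre-trained data archive, and `ufsac-public-2.1/` for the downloaded UFSAC corpora):
+
+```sh
+# Disambiguate a sentence read from the standard input
+DATADIR=/path/to/extracted/data
+echo "The mouse sat on the mat" | ./decode.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd0
+
+# Evaluate the same model on one of the UFSAC evaluation corpora
+./evaluate.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd0 --corpus ufsac-public-2.1/raganato_ALL.xml
+```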
-## Train a WSD model
+## Training new WSD models
+
+### Preparing data
+
+Call the `./prepare_data.sh` script with the following main arguments (an example invocation follows the list):
 
-To train a model, first call the `./prepare_data.sh` script with the following arguments:
 - `--data_path [DIR]` is the path to the directory that will contain the description of the model (files `config.json`, `input_vocabularyX` and `output_vocabularyX`) and the processed training data (files `train` and `dev`)
 - `--train [FILE]...` is the list of corpora in UFSAC format used for the training set
 - `--dev [FILE]...` (optional) is the list of corpora in UFSAC format used for the development set
 - `--dev_from_train [N]` (default `0`) randomly extracts `N` sentences from the training corpus and uses them as the development corpus
 - `--input_features [FEATURE]...` (default `surface_form`) is the list of input features used, as UFSAC attributes. Possible values are, but not limited to, `surface_form`, `lemma`, `pos`, `wn30_key`...
 - `--input_embeddings [FILE]...` (default `null`) is the list of pre-trained embeddings to use for each input feature. Must be the same number of arguments as `input_features`, use special value `null` if you want to train embeddings as part of the model
+- `--input_clear_text [true|false]...` (default `false`) is a list of true/false values (one value for each input feature) indicating if the feature must be used as clear text (e.g. with ELMo/BERT) or as integer values (with classic embeddings). Must be the same number of arguments as `input_features`
 - `--output_features [FEATURE]...` (default `wn30_key`) is the list of output features for the model to predict, as UFSAC attributes. Possible values are the same as input features
 - `--lowercase [true|false]` (default `true`) if you want to enable/disable lowercasing of input
-- `--sense_reduction [true|false]` (default `true`) if you want to enable/disable the sense vocabulary reduction method.
-- `--add_monosemics [true|false]` (default `false`) if you want to consider all monosemic words annotated with their unique sense tag (even if they are not initially annotated)
+- `--sense_compression_hypernyms [true|false]` (default `true`) if you want to enable/disable the sense vocabulary compression through the hypernym/hyponym relationships.
+- `--sense_compression_file [FILE]` if you want to use another sense vocabulary compression mapping.
+- `--add_monosemics [true|false]` (default `false`) if you want to consider all monosemic words annotated with their unique sense tag (even if they are not initially annotated)
 - `--remove_monosemics [true|false]` (default `false`) if you want to remove the tag of all monosemic words
 - `--remove_duplicates [true|false]` (default `true`) if you want to remove duplicate sentences from the training set (output features are merged)
 
-Once the data prepared, tweak the generated `config.json` file to your needs (LSTM layers, embeddings size, dropout rate...)
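+For example, a minimal data preparation run on UFSAC 2.1 corpora might look like this (a sketch with illustrative paths and values, not the exact setup used in the article):
+
+```sh
+./prepare_data.sh --data_path data/wsd \
+                  --train semcor.xml wngt.xml \
+                  --dev_from_train 4000 \
+                  --input_features surface_form \
+                  --output_features wn30_key
+```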
+### Training a model (or an ensemble of models)
+
+Call the `./train.sh` script with the following main arguments (an example invocation follows the list):
 
-Finally, use the `./train.sh` script with the following arguments:
 - `--data_path [DIR]` is the path to the directory generated by `prepare_data.sh` (must contain the files describing the model and the processed training data)
 - `--model_path [DIR]` is the path where the trained model weights and the training info will be saved
 - `--batch_size [N]` (default `100`) is the batch size
 - `--ensemble_count [N]` (default `8`) is the number of different models to train
 - `--epoch_count [N]` (default `100`) is the number of epochs
 - `--eval_frequency [N]` (default `4000`) is the number of batches to process before evaluating the model on the development set. The count resets every epoch, and an evaluation is also performed at the end of every epoch
 - `--update_frequency [N]` (default `1`) is the number of batches to accumulate before backpropagating (if you want to accumulate the gradient of several batches)
 - `--lr [N]` (default `0.0001`) is the initial learning rate of the optimizer (Adam)
+- `--input_embeddings_size [N]` (default `300`) is the size of input embeddings (if not using pre-trained embeddings, BERT nor ELMo)
+- `--input_elmo_model [MODEL]` is the name of the ELMo model to use (one of `small`, `medium` or `original`), it will be downloaded automatically.
+- `--input_bert_model [MODEL]` is the name of the BERT model to use (of the form `bert-{base,large}[-multilingual]-{cased,uncased}`), it will be downloaded automatically.
+- `--encoder_type [ENCODER]` (default `lstm`) is one of `lstm` or `transformer`.
+- `--encoder_lstm_hidden_size [N]` (default `1000`)
+- `--encoder_lstm_layers [N]` (default `1`)
+- `--encoder_lstm_dropout [N]` (default `0.5`)
+- `--encoder_transformer_hidden_size [N]` (default `512`)
+- `--encoder_transformer_layers [N]` (default `6`)
+- `--encoder_transformer_heads [N]` (default `8`)
+- `--encoder_transformer_positional_encoding [true|false]` (default `true`)
+- `--encoder_transformer_dropout [N]` (default `0.1`)
 - `--reset [true|false]` (default `false`) if you do not want to resume a previous training. Be careful, as it will effectively reset the training state and the model weights saved in the `--model_path`
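+For example, training an ensemble of 8 models on the prepared data might look like this (a sketch using the default hyperparameters listed above; paths are illustrative):
+
+```sh
+./train.sh --data_path data/wsd \
+           --model_path models/wsd \
+           --ensemble_count 8 \
+           --epoch_count 100 \
+           --batch_size 100
+```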
diff --git a/java/pom.xml b/java/pom.xml
index ebcf9e2..59ada34 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -44,11 +44,6 @@
 			<artifactId>commons-lang3</artifactId>
 			<version>3.5</version>
 		</dependency>
-		<dependency>
-			<groupId>commons-cli</groupId>
-			<artifactId>commons-cli</artifactId>
-			<version>1.4</version>
-		</dependency>
 		<dependency>
 			<groupId>com.panayotis.javaplot</groupId>
 			<artifactId>javaplot</artifactId>
diff --git a/java/src/main/java/NeuralWSDDecode.java b/java/src/main/java/NeuralWSDDecode.java
index 9de7f9d..682ea91 100644
--- a/java/src/main/java/NeuralWSDDecode.java
+++ b/java/src/main/java/NeuralWSDDecode.java
@@ -1,50 +1,118 @@
 import getalp.wsd.common.wordnet.WordnetHelper;
+import getalp.wsd.method.Disambiguator;
+import getalp.wsd.method.FirstSenseDisambiguator;
 import getalp.wsd.method.neural.NeuralDisambiguator;
 import getalp.wsd.ufsac.core.Sentence;
 import getalp.wsd.ufsac.core.Word;
 import getalp.wsd.ufsac.utils.CorpusPOSTaggerAndLemmatizer;
-import getalp.wsd.utils.ArgumentParser;
+import getalp.wsd.common.utils.ArgumentParser;
 import getalp.wsd.utils.WordnetUtils;
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.InputStreamReader;
-import java.io.OutputStreamWriter;
+
+import java.io.*;
+import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 public class NeuralWSDDecode
 {
     public static void main(String[] args) throws Exception
+    {
+        new NeuralWSDDecode().decode(args);
+    }
+
+    private boolean mfsBackoff;
+
+    private Disambiguator firstSenseDisambiguator;
+
+    private NeuralDisambiguator neuralDisambiguator;
+
+    private BufferedWriter writer;
+
+    private BufferedReader reader;
+
+    private void decode(String[] args) throws Exception
     {
         ArgumentParser parser = new ArgumentParser();
         parser.addArgument("python_path");
         parser.addArgument("data_path");
         parser.addArgumentList("weights");
-        parser.addArgument("lowercase", "true");
-        parser.addArgument("sense_reduction", "true");
+        parser.addArgument("lowercase", "false");
+        parser.addArgument("sense_compression_hypernyms", "true");
+        parser.addArgument("sense_compression_instance_hypernyms", "false");
+        parser.addArgument("sense_compression_antonyms", "false");
+        parser.addArgument("sense_compression_file", "");
+        parser.addArgument("clear_text", "false");
+        parser.addArgument("batch_size", "1");
+        parser.addArgument("truncate_max_length", "150");
+        parser.addArgument("mfs_backoff", "true");
         if (!parser.parse(args)) return;
         String pythonPath = parser.getArgValue("python_path");
         String dataPath = parser.getArgValue("data_path");
         List<String> weights = parser.getArgValueList("weights");
         boolean lowercase = parser.getArgValueBoolean("lowercase");
-        boolean senseReduction = parser.getArgValueBoolean("sense_reduction");
+        boolean senseCompressionHypernyms = parser.getArgValueBoolean("sense_compression_hypernyms");
+        boolean senseCompressionInstanceHypernyms = parser.getArgValueBoolean("sense_compression_instance_hypernyms");
+        boolean senseCompressionAntonyms = parser.getArgValueBoolean("sense_compression_antonyms");
+        String senseCompressionFile = parser.getArgValue("sense_compression_file");
+        boolean clearText = parser.getArgValueBoolean("clear_text");
+        int batchSize = parser.getArgValueInteger("batch_size");
+        int truncateMaxLength = parser.getArgValueInteger("truncate_max_length");
+        mfsBackoff = parser.getArgValueBoolean("mfs_backoff");
+
+        Map<String, String> senseCompressionClusters = null;
+        if (senseCompressionHypernyms || senseCompressionAntonyms)
+        {
+            senseCompressionClusters = WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms);
+        }
+        if (!senseCompressionFile.isEmpty())
+        {
+            senseCompressionClusters = WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile);
+        }
 
         CorpusPOSTaggerAndLemmatizer tagger = new CorpusPOSTaggerAndLemmatizer();
-        NeuralDisambiguator disambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights);
-        disambiguator.lowercaseWords = lowercase;
-        if (senseReduction) disambiguator.reducedOutputVocabulary = WordnetUtils.getReducedSynsetKeysWithHypernyms3(WordnetHelper.wn30());
-        else disambiguator.reducedOutputVocabulary = null;
-
-        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
-        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out));
-        for (String line = reader.readLine() ; line != null ; line = reader.readLine())
+        firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30());
+        neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);
+        neuralDisambiguator.lowercaseWords = lowercase;
+        neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters;
+
+        reader = new BufferedReader(new InputStreamReader(System.in));
+        writer = new BufferedWriter(new OutputStreamWriter(System.out));
+        List<Sentence> sentences = new ArrayList<>();
+        for (String line = reader.readLine(); line != null ; line = reader.readLine())
         {
             Sentence sentence = new Sentence(line);
+            if (sentence.getWords().size() > truncateMaxLength)
+            {
+                sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord);
+            }
             tagger.tag(sentence.getWords());
-            disambiguator.disambiguate(sentence, "wsd");
+            sentences.add(sentence);
+            if (sentences.size() >= batchSize)
+            {
+                decodeSentenceBatch(sentences);
+                sentences.clear();
+            }
+        }
+        decodeSentenceBatch(sentences);
+        writer.close();
+        reader.close();
+        neuralDisambiguator.close();
+    }
+
+    private void decodeSentenceBatch(List<Sentence> sentences) throws IOException
+    {
+        neuralDisambiguator.disambiguateDynamicSentenceBatch(sentences, "wsd", "");
+        for (Sentence sentence : sentences)
+        {
+            if (mfsBackoff)
+            {
+                firstSenseDisambiguator.disambiguate(sentence, "wsd");
+            }
             for (Word word : sentence.getWords())
             {
-                writer.write(word.getValue().replace("|", ""));
+                writer.write(word.getValue().replace("|", "/"));
                 if (word.hasAnnotation("lemma") && word.hasAnnotation("pos") && word.hasAnnotation("wsd"))
                 {
                     writer.write("|" + word.getAnnotationValue("wsd"));
@@ -52,11 +120,8 @@ public static void main(String[] args) throws Exception
                 writer.write(" ");
             }
             writer.newLine();
-            writer.flush();
         }
-        writer.close();
-        reader.close();
-        disambiguator.close();
+        writer.flush();
     }
 }
diff --git a/java/src/main/java/NeuralWSDPrepare.java b/java/src/main/java/NeuralWSDPrepare.java
index 4b8c30d..7f6b6b7 100644
--- a/java/src/main/java/NeuralWSDPrepare.java
+++ b/java/src/main/java/NeuralWSDPrepare.java
@@ -1,9 +1,8 @@
-import java.util.Collections;
-import java.util.List;
+import java.util.*;
 
 import getalp.wsd.common.wordnet.WordnetHelper;
 import getalp.wsd.method.neural.NeuralDataPreparator;
-import getalp.wsd.utils.ArgumentParser;
+import getalp.wsd.common.utils.ArgumentParser;
 import getalp.wsd.utils.WordnetUtils;
 
 public class NeuralWSDPrepare
@@ -14,13 +13,26 @@ public static void main(String[] args) throws Exception
         parser.addArgument("data_path");
         parser.addArgumentList("train");
         parser.addArgumentList("dev", Collections.emptyList());
parser.addArgument("dev_from_train", "0"); + parser.addArgument("corpus_format", "xml"); + parser.addArgumentList("txt_corpus_features", Collections.singletonList("null")); parser.addArgumentList("input_features", Collections.singletonList("surface_form")); parser.addArgumentList("input_embeddings", Collections.singletonList("null")); + parser.addArgumentList("input_vocabulary", Collections.singletonList("null")); + parser.addArgument("input_vocabulary_limit", "-1"); + parser.addArgumentList("input_clear_text", Collections.singletonList("false")); parser.addArgumentList("output_features", Collections.singletonList("wn30_key")); - parser.addArgument("lowercase", "true"); + parser.addArgument("output_feature_vocabulary_limit", "-1"); + parser.addArgument("truncate_line_length", "80"); + parser.addArgument("exclude_line_length", "150"); + parser.addArgument("line_length_tokenizer", "null"); + parser.addArgument("lowercase", "false"); parser.addArgument("uniform_dash", "false"); - parser.addArgument("dev_from_train", "0"); - parser.addArgument("sense_reduction", "true"); + parser.addArgument("sense_compression_hypernyms", "true"); + parser.addArgument("sense_compression_instance_hypernyms", "false"); + parser.addArgument("sense_compression_antonyms", "false"); + parser.addArgument("sense_compression_file", ""); + parser.addArgument("add_wordkey_from_sensekey", "false"); parser.addArgument("add_monosemics", "false"); parser.addArgument("remove_monosemics", "false"); parser.addArgument("remove_duplicates", "true"); @@ -29,19 +41,52 @@ public static void main(String[] args) throws Exception String dataPath = parser.getArgValue("data_path"); List trainingCorpusPaths = parser.getArgValueList("train"); List devCorpusPaths = parser.getArgValueList("dev"); + int devFromTrain = parser.getArgValueInteger("dev_from_train"); + String corpusFormat = parser.getArgValue("corpus_format"); + List txtCorpusFeatures = parser.getArgValueList("txt_corpus_features"); List inputFeatures = parser.getArgValueList("input_features"); List inputEmbeddings = parser.getArgValueList("input_embeddings"); + List inputVocabulary = parser.getArgValueList("input_vocabulary"); + int inputVocabularyLimit = parser.getArgValueInteger("input_vocabulary_limit"); + List inputClearText = parser.getArgValueBooleanList("input_clear_text"); List outputFeatures = parser.getArgValueList("output_features"); + int outputFeatureVocabularyLimit = parser.getArgValueInteger("output_feature_vocabulary_limit"); + int maxLineLength = parser.getArgValueInteger("truncate_line_length"); boolean lowercase = parser.getArgValueBoolean("lowercase"); boolean uniformDash = parser.getArgValueBoolean("uniform_dash"); - int devFromTrain = parser.getArgValueInteger("dev_from_train"); - boolean senseReduction = parser.getArgValueBoolean("sense_reduction"); + boolean senseCompressionHypernyms = parser.getArgValueBoolean("sense_compression_hypernyms"); + boolean senseCompressionInstanceHypernyms = parser.getArgValueBoolean("sense_compression_instance_hypernyms"); + boolean senseCompressionAntonyms = parser.getArgValueBoolean("sense_compression_antonyms"); + String senseCompressionFile = parser.getArgValue("sense_compression_file"); + boolean addWordKeyFromSenseKey = parser.getArgValueBoolean("add_wordkey_from_sensekey"); boolean addMonosemics = parser.getArgValueBoolean("add_monosemics"); boolean removeMonosemics = parser.getArgValueBoolean("remove_monosemics"); boolean removeDuplicateSentences = parser.getArgValueBoolean("remove_duplicates"); + Map 
+        if (senseCompressionHypernyms || senseCompressionAntonyms)
+        {
+            senseCompressionClusters = WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms);
+        }
+        if (!senseCompressionFile.isEmpty())
+        {
+            senseCompressionClusters = WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile);
+        }
+
+        inputEmbeddings = padList(inputEmbeddings, inputFeatures.size(), "null");
+        inputVocabulary = padList(inputVocabulary, inputFeatures.size(), "null");
+        inputClearText = padList(inputClearText, inputFeatures.size(), false);
+
         NeuralDataPreparator preparator = new NeuralDataPreparator();
 
+        preparator.addWordKeyFromSenseKey = addWordKeyFromSenseKey;
+
+        if (txtCorpusFeatures.size() == 1 && txtCorpusFeatures.get(0).equals("null"))
+        {
+            txtCorpusFeatures = Collections.emptyList();
+        }
+        preparator.txtCorpusFeatures = txtCorpusFeatures;
+
         preparator.setOutputDirectoryPath(dataPath);
         for (String corpusPath : trainingCorpusPaths)
@@ -54,56 +99,51 @@ public static void main(String[] args) throws Exception
             preparator.addDevelopmentCorpus(corpusPath);
         }
 
-        assert(inputFeatures.size() == inputEmbeddings.size());
-
-        for (int i = 0 ; i < inputFeatures.size() ; i++)
+        for (int i = 0; i < inputFeatures.size(); i++)
         {
             String inputFeatureAnnotationName = inputFeatures.get(i);
-            String inputFeatureEmbeddings = inputEmbeddings.get(i);
-            if (inputFeatureEmbeddings.equals("null"))
-            {
-                preparator.addInputFeature(inputFeatureAnnotationName);
-            }
-            else
-            {
-                preparator.addInputFeature(inputFeatureAnnotationName, inputFeatureEmbeddings);
-            }
+            String inputFeatureEmbeddings = inputEmbeddings.get(i).equals("null") ? null : inputEmbeddings.get(i);
+            String inputFeatureVocabulary = inputVocabulary.get(i).equals("null") ? null : inputVocabulary.get(i);
+            preparator.addInputFeature(inputFeatureAnnotationName, inputFeatureEmbeddings, inputFeatureVocabulary);
+        }
+
+        if (outputFeatures.size() == 1 && outputFeatures.get(0).equals("null"))
+        {
+            outputFeatures.clear();
         }
 
-        for (int i = 0 ; i < outputFeatures.size() ; i++)
+        for (int i = 0; i < outputFeatures.size(); i++)
         {
-            preparator.addOutputFeature(outputFeatures.get(i));
-            /*
-            String outputFeatureAnnotationName = outputFeatures.get(i);
-            String outputFeatureVocabulary = "null";
-            if (i + 1 < outputFeatures.size())
-            {
-                outputFeatureVocabulary = outputFeatures.get(i + 1);
-            }
-            if (outputFeatureVocabulary.equals("null"))
-            {
-                preparator.addOutputFeature(outputFeatureAnnotationName);
-            }
-            else
-            {
-                preparator.addOutputFeature(outputFeatureAnnotationName, outputFeatureVocabulary);
-            }
-            */
+            preparator.addOutputFeature(outputFeatures.get(i), null);
         }
 
-        preparator.maxLineLength = 80;
+        preparator.setCorpusFormat(corpusFormat);
+        preparator.setInputVocabularyLimit(inputVocabularyLimit);
+        preparator.setInputClearText(inputClearText);
+        preparator.setOutputFeatureVocabularyLimit(outputFeatureVocabularyLimit);
+
+        preparator.maxLineLength = maxLineLength;
         preparator.lowercaseWords = lowercase;
         preparator.uniformDash = uniformDash;
         preparator.multisenses = false;
         preparator.removeAllCoarseGrained = true;
         preparator.addMonosemics = addMonosemics;
         preparator.removeMonosemics = removeMonosemics;
-        if (senseReduction) preparator.reducedOutputVocabulary = WordnetUtils.getReducedSynsetKeysWithHypernyms3(WordnetHelper.wn30());
-        else preparator.reducedOutputVocabulary = null;
+        preparator.reducedOutputVocabulary = senseCompressionClusters;
         preparator.additionalDevFromTrainSize = devFromTrain;
         preparator.removeDuplicateSentences = removeDuplicateSentences;
         preparator.prepareTrainingFile();
     }
+
+    private static <T> List<T> padList(List<T> list, int padSize, T padValue)
+    {
+        List<T> newList = new ArrayList<>(list);
+        while (newList.size() < padSize)
+        {
+            newList.add(padValue);
+        }
+        return newList;
+    }
 }
diff --git a/java/src/main/java/NeuralWSDTest.java b/java/src/main/java/NeuralWSDTest.java
index adab74f..6b97832 100644
--- a/java/src/main/java/NeuralWSDTest.java
+++ b/java/src/main/java/NeuralWSDTest.java
@@ -1,6 +1,7 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 
 import getalp.wsd.common.wordnet.WordnetHelper;
 import getalp.wsd.evaluation.WSDEvaluator;
@@ -11,7 +12,7 @@ import getalp.wsd.method.result.DisambiguationResult;
 import getalp.wsd.method.result.MultipleDisambiguationResult;
 import getalp.wsd.ufsac.core.*;
-import getalp.wsd.utils.ArgumentParser;
+import getalp.wsd.common.utils.ArgumentParser;
 import getalp.wsd.utils.WordnetUtils;
 
 public class NeuralWSDTest
@@ -26,7 +27,13 @@ public class NeuralWSDTest
 
     private boolean lowercase;
 
-    private boolean senseReduction;
+    private Map<String, String> senseCompressionClusters;
+
+    private boolean filterLemma;
+
+    private boolean clearText;
+
+    private int batchSize;
 
     private Disambiguator monosemicDisambiguator;
 
@@ -34,7 +41,7 @@ public class NeuralWSDTest
 
     private WSDEvaluator evaluator;
 
-    public void test(String[] args) throws Exception
+    private void test(String[] args) throws Exception
     {
         ArgumentParser parser = new ArgumentParser();
         parser.addArgument("python_path");
@@ -48,8 +55,14 @@ public void test(String[] args) throws Exception
                 "ufsac-public-2.1/raganato_semeval2015.xml",
                 "ufsac-public-2.1/raganato_ALL.xml",
                 "ufsac-public-2.1/semeval2007task7.xml"));
-        parser.addArgument("lowercase", "true");
parser.addArgument("lowercase", "true"); - parser.addArgument("sense_reduction", "true"); + parser.addArgument("lowercase", "false"); + parser.addArgument("sense_compression_hypernyms", "true"); + parser.addArgument("sense_compression_instance_hypernyms", "false"); + parser.addArgument("sense_compression_antonyms", "false"); + parser.addArgument("sense_compression_file", ""); + parser.addArgument("filter_lemma", "true"); + parser.addArgument("clear_text", "false"); + parser.addArgument("batch_size", "1"); if (!parser.parse(args, true)) return; pythonPath = parser.getArgValue("python_path"); @@ -57,7 +70,23 @@ public void test(String[] args) throws Exception weights = parser.getArgValueList("weights"); testCorpusPaths = parser.getArgValueList("corpus"); lowercase = parser.getArgValueBoolean("lowercase"); - senseReduction = parser.getArgValueBoolean("sense_reduction"); + boolean senseCompressionHypernyms = parser.getArgValueBoolean("sense_compression_hypernyms"); + boolean senseCompressionInstanceHypernyms = parser.getArgValueBoolean("sense_compression_instance_hypernyms"); + boolean senseCompressionAntonyms = parser.getArgValueBoolean("sense_compression_antonyms"); + String senseCompressionFile = parser.getArgValue("sense_compression_file"); + filterLemma = parser.getArgValueBoolean("filter_lemma"); + clearText = parser.getArgValueBoolean("clear_text"); + batchSize = parser.getArgValueInteger("batch_size"); + + senseCompressionClusters = null; + if (senseCompressionHypernyms || senseCompressionAntonyms) + { + senseCompressionClusters = WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms); + } + if (!senseCompressionFile.isEmpty()) + { + senseCompressionClusters = WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile); + } monosemicDisambiguator = new MonosemicDisambiguator(WordnetHelper.wn30()); firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30()); @@ -79,16 +108,10 @@ public void test(String[] args) throws Exception private void evaluate_ensemble() throws Exception { - NeuralDisambiguator neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights); + NeuralDisambiguator neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize); neuralDisambiguator.lowercaseWords = lowercase; - if (senseReduction) - { - neuralDisambiguator.reducedOutputVocabulary = WordnetUtils.getReducedSynsetKeysWithHypernyms3(WordnetHelper.wn30()); - } - else - { - neuralDisambiguator.reducedOutputVocabulary = null; - } + neuralDisambiguator.filterLemma = filterLemma; + neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters; for (String testCorpusPath : testCorpusPaths) { System.out.println("Evaluate on corpus " + testCorpusPath); @@ -107,37 +130,41 @@ private void evaluate_ensemble() throws Exception private void evaluate_mean_scores() throws Exception { List neuralDisambiguators = new ArrayList<>(); - for (int i = 0; i < weights.size(); i++) + for (String weight : weights) { - neuralDisambiguators.add(new NeuralDisambiguator(pythonPath, dataPath, weights.get(i))); + NeuralDisambiguator neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weight, clearText, batchSize); + neuralDisambiguator.lowercaseWords = lowercase; + neuralDisambiguator.filterLemma = filterLemma; + neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters; + neuralDisambiguators.add(neuralDisambiguator); } for 
         {
             System.out.println("Evaluate on corpus " + testCorpusPath);
-            MultipleDisambiguationResult results = new MultipleDisambiguationResult();
+            MultipleDisambiguationResult resultsBackoffZero = new MultipleDisambiguationResult();
+            MultipleDisambiguationResult resultsBackoffMonosemics = new MultipleDisambiguationResult();
+            MultipleDisambiguationResult resultsBackoffFirstSense = new MultipleDisambiguationResult();
             for (int i = 0; i < weights.size(); i++)
             {
                 NeuralDisambiguator neuralDisambiguator = neuralDisambiguators.get(i);
-                neuralDisambiguator.lowercaseWords = lowercase;
-                if (senseReduction)
-                {
-                    neuralDisambiguator.reducedOutputVocabulary = WordnetUtils.getReducedSynsetKeysWithHypernyms3(WordnetHelper.wn30());
-                }
-                else
-                {
-                    neuralDisambiguator.reducedOutputVocabulary = null;
-                }
                 Corpus testCorpus = Corpus.loadFromXML(testCorpusPath);
                 System.out.println("" + i + " Evaluate without backoff");
-                evaluator.evaluate(neuralDisambiguator, testCorpus, "wn30_key");
+                DisambiguationResult resultBackoffZero = evaluator.evaluate(neuralDisambiguator, testCorpus, "wn30_key");
                 System.out.println("" + i + " Evaluate with monosemics");
-                evaluator.evaluate(monosemicDisambiguator, testCorpus, "wn30_key");
+                DisambiguationResult resultBackoffMonosemics = evaluator.evaluate(monosemicDisambiguator, testCorpus, "wn30_key");
                 System.out.println("" + i + " Evaluate with backoff first sense");
-                DisambiguationResult result = evaluator.evaluate(firstSenseDisambiguator, testCorpus, "wn30_key");
-                results.addDisambiguationResult(result);
+                DisambiguationResult resultBackoffFirstSense = evaluator.evaluate(firstSenseDisambiguator, testCorpus, "wn30_key");
+                resultsBackoffZero.addDisambiguationResult(resultBackoffZero);
+                resultsBackoffMonosemics.addDisambiguationResult(resultBackoffMonosemics);
+                resultsBackoffFirstSense.addDisambiguationResult(resultBackoffFirstSense);
             }
-            System.out.println("Mean Scores : " + results.scoreMean());
-            System.out.println("Standard Deviation Scores : " + results.scoreStandardDeviation());
+            System.out.println();
+            System.out.println("Mean of scores without backoff: " + resultsBackoffZero.scoreMean());
+            System.out.println("Standard deviation without backoff: " + resultsBackoffZero.scoreStandardDeviation());
+            System.out.println("Mean of scores with monosemics: " + resultsBackoffMonosemics.scoreMean());
+            System.out.println("Standard deviation with monosemics: " + resultsBackoffMonosemics.scoreStandardDeviation());
+            System.out.println("Mean of scores with backoff first sense: " + resultsBackoffFirstSense.scoreMean());
+            System.out.println("Standard deviation with backoff first sense: " + resultsBackoffFirstSense.scoreStandardDeviation());
             System.out.println();
         }
         for (int i = 0; i < weights.size(); i++)
diff --git a/java/src/main/java/getalp/wsd/embeddings/TextualModelLoader.java b/java/src/main/java/getalp/wsd/embeddings/TextualModelLoader.java
index f993abc..dad6f99 100644
--- a/java/src/main/java/getalp/wsd/embeddings/TextualModelLoader.java
+++ b/java/src/main/java/getalp/wsd/embeddings/TextualModelLoader.java
@@ -18,7 +18,7 @@ public class TextualModelLoader
 
     public TextualModelLoader()
     {
-        this(true);
+        this(false);
     }
 
     public TextualModelLoader(boolean verbose)
diff --git a/java/src/main/java/getalp/wsd/evaluation/WSDEvaluator.java b/java/src/main/java/getalp/wsd/evaluation/WSDEvaluator.java
index 99cca3c..420866a 100644
--- a/java/src/main/java/getalp/wsd/evaluation/WSDEvaluator.java
+++ b/java/src/main/java/getalp/wsd/evaluation/WSDEvaluator.java
@@ -5,9 +5,9 @@ import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.*;
 
+import getalp.wsd.common.utils.POSConverter;
 import getalp.wsd.common.wordnet.WordnetHelper;
 import getalp.wsd.method.Disambiguator;
 import getalp.wsd.method.result.MultipleDisambiguationResult;
@@ -80,8 +80,15 @@ public DisambiguationResult evaluate(Disambiguator disambiguator, Corpus corpus,
         long endTime = System.currentTimeMillis();
         double time = (endTime - startTime) / 1000.0;
         totalScore.time = time;
-        print("; good/bad/missed/total : " + totalScore.good + "/" + totalScore.bad + "/" + totalScore.missed + "/" + totalScore.total);
-        println(" ; C/P/R/F1 : " + String.format("%.4f", totalScore.coverage()) + "/" + String.format("%.4f", totalScore.scorePrecision()) + "/" + String.format("%.4f", totalScore.scoreRecall()) + "/" + String.format("%.4f", totalScore.scoreF1()) + " ; time : " + totalScore.time + " seconds");
+        print("; good/bad/missed/total : " + totalScore.good + "/" + totalScore.bad + "/" + totalScore.missed() + "/" + totalScore.total);
+        print(" ; C/P/R/F1 : " + String.format("%.4f", totalScore.coverage()) + "/" + String.format("%.4f", totalScore.scorePrecision()) + "/" + String.format("%.4f", totalScore.scoreRecall()) + "/" + String.format("%.4f", totalScore.scoreF1()));
+        for (String pos : Arrays.asList("n", "v", "a", "r", "x"))
+        {
+            print(" ; [" + pos + "] good/bad/missed/total : " + totalScore.goodPerPOS.get(pos) + "/" + totalScore.badPerPOS.get(pos) + "/" + totalScore.missedPerPOS(pos) + "/" + totalScore.totalPerPOS.get(pos));
+            print(" ; [" + pos + "] C/P/R/F1 : " + String.format("%.4f", totalScore.coveragePerPOS(pos)) + "/" + String.format("%.4f", totalScore.scorePrecisionPerPOS(pos)) + "/" + String.format("%.4f", totalScore.scoreRecallPerPOS(pos)) + "/" + String.format("%.4f", totalScore.scoreF1PerPOS(pos)));
+
+        }
+        println(" ; time : " + totalScore.time + " seconds");
         printFailed();
         saveResultToFile(corpus.getDocuments(), "wsd_test");
         return totalScore;
@@ -99,12 +106,11 @@ public DisambiguationResult computeDisambiguationResult(List<Word> wordList, Str
 
     public DisambiguationResult computeDisambiguationResult(List<Word> wordList, String referenceSenseTag, String candidateSenseTag, String confidenceValueTag, double confidenceThreshold, WordnetHelper wn)
     {
-        int total = 0;
-        int good = 0;
-        int bad = 0;
+        DisambiguationResult res = new DisambiguationResult();
         for (int i = 0 ; i < wordList.size() ; i++)
         {
             Word word = wordList.get(i);
+            String wordPOS = POSConverter.toWNPOS(word.getAnnotationValue("pos"));
             List<String> referenceSenseKeys = word.getAnnotationValues(referenceSenseTag, ";");
             if (referenceSenseKeys.isEmpty()) continue;
@@ -122,7 +128,8 @@ public DisambiguationResult computeDisambiguationResult(List<Word> wordList, Str
             }
             if (referenceSynsetKeys.isEmpty()) continue;
-            total += 1;
+            res.total += 1;
+            res.totalPerPOS.put(wordPOS, res.totalPerPOS.get(wordPOS) + 1);
             String candidateSenseKey = word.getAnnotationValue(candidateSenseTag);
             if (candidateSenseKey.isEmpty()) continue;
@@ -134,18 +141,21 @@ public DisambiguationResult computeDisambiguationResult(List<Word> wordList, Str
                 if (confidenceValue != Double.POSITIVE_INFINITY && confidenceValue < confidenceThreshold) continue;
             }
             String candidateSynsetKey = wn.getSynsetKeyFromSenseKey(candidateSenseKey);
-            bad += 1;
+            res.bad += 1;
+            res.badPerPOS.put(wordPOS, res.badPerPOS.get(wordPOS) + 1);
             for (String refSynsetKey : referenceSynsetKeys)
             {
                 if (refSynsetKey.equals(candidateSynsetKey))
                 {
-                    good += 1;
-                    bad -= 1;
+                    res.good += 1;
+                    res.bad -= 1;
+                    res.goodPerPOS.put(wordPOS, res.goodPerPOS.get(wordPOS) + 1);
+                    res.badPerPOS.put(wordPOS, res.badPerPOS.get(wordPOS) - 1);
                     break;
                 }
             }
         }
-        return new DisambiguationResult(total, good, bad);
+        return res;
     }
 
     private void saveResultToFile(List<Document> documents, String candidateSenseTag)
diff --git a/java/src/main/java/getalp/wsd/method/DisambiguatorContextSentenceBatch.java b/java/src/main/java/getalp/wsd/method/DisambiguatorContextSentenceBatch.java
new file mode 100644
index 0000000..0925ca5
--- /dev/null
+++ b/java/src/main/java/getalp/wsd/method/DisambiguatorContextSentenceBatch.java
@@ -0,0 +1,69 @@
+package getalp.wsd.method;
+
+import getalp.wsd.ufsac.core.*;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public abstract class DisambiguatorContextSentenceBatch extends Disambiguator
+{
+    private int batchSize;
+
+    public DisambiguatorContextSentenceBatch(int batchSize)
+    {
+        this.batchSize = batchSize;
+    }
+
+    @Override
+    public void disambiguate(Corpus corpus, String newSenseTags, String confidenceTag)
+    {
+        disambiguateDynamicSentenceBatch(corpus.getSentences(), newSenseTags, confidenceTag);
+    }
+
+    @Override
+    public void disambiguate(Document document, String newSenseTags, String confidenceTag)
+    {
+        disambiguateDynamicSentenceBatch(document.getSentences(), newSenseTags, confidenceTag);
+    }
+
+    @Override
+    public void disambiguate(Paragraph paragraph, String newSenseTags, String confidenceTag)
+    {
+        disambiguateDynamicSentenceBatch(paragraph.getSentences(), newSenseTags, confidenceTag);
+    }
+
+    @Override
+    public void disambiguate(Sentence sentence, String newSenseTags, String confidenceTag)
+    {
+        disambiguateDynamicSentenceBatch(Collections.singletonList(sentence), newSenseTags, confidenceTag);
+    }
+
+    @Override
+    public void disambiguate(List<Word> words, String newSenseTags, String confidenceTag)
+    {
+        disambiguateDynamicSentenceBatch(Collections.singletonList(new Sentence(words)), newSenseTags, confidenceTag);
+    }
+
+    public void disambiguateDynamicSentenceBatch(List<Sentence> originalSentences, String newSenseTags, String confidenceTag)
+    {
+        List<Sentence> sentences = new ArrayList<>(originalSentences);
+        while (sentences.size() > batchSize)
+        {
+            List<Sentence> subSentences = sentences.subList(0, batchSize);
+            disambiguateFixedSentenceBatch(subSentences, newSenseTags, confidenceTag);
+            subSentences.clear();
+        }
+        if (!sentences.isEmpty())
+        {
+            int paddingSize = batchSize - sentences.size();
+            for (int i = 0 ; i < paddingSize ; i++)
+            {
+                sentences.add(new Sentence(""));
+            }
+            disambiguateFixedSentenceBatch(sentences, newSenseTags, confidenceTag);
+        }
+    }
+
+    protected abstract void disambiguateFixedSentenceBatch(List<Sentence> sentences, String newSenseTags, String confidenceTag);
+}
diff --git a/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java b/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java
index 32b9459..8896a39 100644
--- a/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java
+++ b/java/src/main/java/getalp/wsd/method/neural/NeuralDataPreparator.java
@@ -3,10 +3,12 @@
 import java.io.BufferedReader;
 import java.io.IOException;
 import java.util.*;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 import java.nio.file.Files;
 import java.nio.file.Paths;
+import java.util.stream.IntStream;
 
 import getalp.wsd.common.utils.POSConverter;
 import getalp.wsd.common.utils.StringUtils;
@@ -14,8 +16,7 @@ import getalp.wsd.common.wordnet.WordnetHelper;
 import getalp.wsd.embeddings.TextualModelLoader;
 import getalp.wsd.embeddings.WordVectors;
-import getalp.wsd.ufsac.core.Sentence;
-import getalp.wsd.ufsac.core.Word;
+import getalp.wsd.ufsac.simple.core.*;
 import getalp.wsd.ufsac.streaming.reader.StreamingCorpusReaderSentence;
 import getalp.wsd.utils.Json;
 import getalp.wsd.utils.WordnetUtils;
@@ -27,6 +28,10 @@ public class NeuralDataPreparator
 
     private static final String unknownToken = "<unk>"; // input index "1"
 
+    private static final String beginningOfSentenceToken = "<bos>"; // input index "2"
+
+    private static final String endOfSentenceToken = "<eos>"; // input index "3"
+
     private static final String skipToken = "<skip>"; // output index "0"
 
@@ -34,6 +39,10 @@ public class NeuralDataPreparator
 
     private static final String outputVocabularyFileName = "/output_vocabulary";
 
+    private static final String outputTranslationVocabularyFileName1 = "/output_translation";
+
+    private static final String outputTranslationVocabularyFileName2 = "_vocabulary";
+
     private static final String trainFileName = "/train";
 
     private static final String devFileName = "/dev";
@@ -47,21 +56,35 @@ public class NeuralDataPreparator
 
     private int inputFeatures = 0;
 
+    public List<String> txtCorpusFeatures = new ArrayList<>();
+
     private List<String> inputAnnotationName = new ArrayList<>();
 
     private List<String> inputEmbeddingsPath = new ArrayList<>();
 
-    private List<Map<String, Integer>> inputVocabulary = new ArrayList<>();
+    private List<String> inputVocabularyPath = new ArrayList<>();
 
-    private List<Integer> inputVocabularyCurrentIndex = new ArrayList<>();
+    private List<Map<String, Integer>> inputVocabulary = new ArrayList<>();
 
     private int outputFeatures = 0;
 
     private List<String> outputAnnotationName = new ArrayList<>();
 
+    private List<String> outputFixedVocabularyPath = new ArrayList<>();
+
     private List<Map<String, Integer>> outputVocabulary = new ArrayList<>();
 
-    private List<Integer> currentOutputVocabularyIndex = new ArrayList<>();
+    private int outputTranslations = 0;
+
+    private List<String> outputTranslationName = new ArrayList<>();
+
+    private int outputTranslationFeatures = 0;
+
+    private List<String> outputTranslationAnnotationName = new ArrayList<>();
+
+    private List<String> outputTranslationFixedVocabularyPath = new ArrayList<>();
+
+    private List<List<Map<String, Integer>>> outputTranslationVocabulary = new ArrayList<>();
 
     private String outputDirectoryPath = "data/neural/wsd/";
@@ -71,7 +94,21 @@ public class NeuralDataPreparator
 
     private int outputFeatureSenseIndex = -1;
 
-    private List<String> fixedOutputVocabularyPath = new ArrayList<>();
+    private String corpusFormat;
+
+    private int inputVocabularyLimit;
+
+    private List<Boolean> inputClearText;
+
+    private int outputFeatureVocabularyLimit;
+
+    private int outputTranslationVocabularyLimit;
+
+    private boolean outputTranslationClearText;
+
+    private boolean shareTranslationVocabulary;
+
+    private Set<String> extraWordKeys = null;
 
     // --- begin public options
@@ -79,6 +116,8 @@ public class NeuralDataPreparator
 
     public boolean lowercaseWords = true;
 
+    public boolean addWordKeyFromSenseKey = false;
+
     public boolean uniformDash = false;
 
     public boolean multisenses = false;
@@ -97,7 +136,6 @@ public class NeuralDataPreparator
 
     // --- end public options
 
-
     public NeuralDataPreparator()
     {
@@ -118,73 +156,164 @@ public void addDevelopmentCorpus(String corpusPath)
         originalDevPaths.add(corpusPath);
     }
 
-    public void addInputFeature(String annotationName)
-    {
-        this.addInputFeature(annotationName, null);
-    }
-
-    public void addInputFeature(String annotationName, String embeddingsPath)
+    public void addInputFeature(String annotationName, String embeddingsPath, String vocabularyPath)
     {
         inputFeatures += 1;
         inputAnnotationName.add(annotationName);
         inputEmbeddingsPath.add(embeddingsPath);
-    }
-
-    public void addOutputFeature(String annotationName)
-    {
-        addOutputFeature(annotationName, null);
+        inputVocabularyPath.add(vocabularyPath);
     }
 
     public void addOutputFeature(String annotationName, String vocabularyPath)
     {
         outputFeatures += 1;
         outputAnnotationName.add(annotationName);
-        fixedOutputVocabularyPath.add(vocabularyPath);
+        outputFixedVocabularyPath.add(vocabularyPath);
         if (annotationName.equals(senseTag))
         {
             outputFeatureSenseIndex = outputFeatures - 1;
         }
     }
 
+    public void addOutputTranslation(String translationName, List<String> translationAnnotationName, String vocabularyPath)
+    {
+        outputTranslations += 1;
+        outputTranslationName.add(translationName);
+        outputTranslationFixedVocabularyPath.add(vocabularyPath);
+        outputTranslationFeatures = translationAnnotationName.size();
+        outputTranslationAnnotationName = new ArrayList<>(translationAnnotationName);
+    }
+
+    public void setCorpusFormat(String corpusFormat)
+    {
+        this.corpusFormat = corpusFormat;
+    }
+
+    public void setInputVocabularyLimit(int inputVocabularyLimit)
+    {
+        this.inputVocabularyLimit = inputVocabularyLimit;
+    }
+
+    public void setInputClearText(List<Boolean> inputClearText)
+    {
+        this.inputClearText = inputClearText;
+    }
+
+    public void setOutputFeatureVocabularyLimit(int outputFeatureVocabularyLimit)
+    {
+        this.outputFeatureVocabularyLimit = outputFeatureVocabularyLimit;
+    }
+
+    public void setOutputTranslationVocabularyLimit(int outputTranslationVocabularyLimit)
+    {
+        this.outputTranslationVocabularyLimit = outputTranslationVocabularyLimit;
+    }
+
+    public void setOutputTranslationClearText(boolean outputTranslationClearText)
+    {
+        this.outputTranslationClearText = outputTranslationClearText;
+    }
+
+    public void setShareTranslationVocabulary(boolean shareTranslationVocabulary)
+    {
+        this.shareTranslationVocabulary = shareTranslationVocabulary;
+    }
+
+    public void setExtraWordKeys(Set<String> extraWordKeys)
+    {
+        if (extraWordKeys != null && extraWordKeys.isEmpty()) extraWordKeys = null;
+        this.extraWordKeys = extraWordKeys;
+    }
+
     public void prepareTrainingFile() throws Exception
     {
         Files.createDirectories(Paths.get(outputDirectoryPath));
+
+        List<Sentence> trainSentences = extractSentencesFromCorpora(originalTrainPaths);
+        List<Sentence> devSentences = extractSentencesFromCorpora(originalDevPaths);
+        List<List<Sentence>> translationTrainSentences = new ArrayList<>();
+        List<List<Sentence>> translationDevSentences = new ArrayList<>();
+        for (int i = 0 ; i < outputTranslations ; i++)
+        {
+            String translationName = outputTranslationName.get(i);
+            List<String> translationTrainPaths = originalTrainPaths.stream().map(str -> getTranslationCorpusName(str, translationName)).collect(Collectors.toList());
+            List<String> translationDevPaths = originalDevPaths.stream().map(str -> getTranslationCorpusName(str, translationName)).collect(Collectors.toList());
+            List<Sentence> translationNameTrainSentences = extractSentencesFromCorpora(translationTrainPaths);
+            List<Sentence> translationNameDevSentences = extractSentencesFromCorpora(translationDevPaths);
+            assert(translationNameTrainSentences.size() == trainSentences.size());
+            assert(translationNameDevSentences.size() == devSentences.size());
+            translationTrainSentences.add(translationNameTrainSentences);
+            translationDevSentences.add(translationNameDevSentences);
+        }
+        extractDevSentencesParallel(trainSentences, translationTrainSentences, devSentences, translationDevSentences, additionalDevFromTrainSize);
+
+        buildExtraWordKeysVocabulary(trainSentences, true);
+        buildExtraWordKeysVocabulary(devSentences, false);
+
         initInputVocabulary();
         initOutputVocabulary();
-        List<Sentence> trainSentences = preprocessCorpora(originalTrainPaths, false);
-        if (removeDuplicateSentences)
+        initOutputTranslationVocabulary();
+
+        buildVocabulary(trainSentences, inputAnnotationName, inputEmbeddingsPath, inputVocabulary, true, inputVocabularyLimit);
+        buildVocabulary(trainSentences, outputAnnotationName, outputFixedVocabularyPath, outputVocabulary, false, outputFeatureVocabularyLimit);
+        for (int i = 0 ; i < outputTranslations ; i++)
         {
-            trainSentences = removeDuplicates(trainSentences);
+            buildVocabulary(translationTrainSentences.get(i), outputTranslationAnnotationName, outputTranslationFixedVocabularyPath, outputTranslationVocabulary.get(i), true, outputTranslationVocabularyLimit);
         }
-        List<Sentence> additionalDevSentences = extractDevSentences(trainSentences, additionalDevFromTrainSize);
+        if (shareTranslationVocabulary && outputTranslations > 0 && outputTranslationFeatures > 0)
+        {
+            // TODO: construct vocabulary together in buildVocabulary
+            // TODO: allow sharing vocabularies of multiples features and languages
+            mergeVocabularies(inputVocabulary.get(0), outputTranslationVocabulary.get(0).get(0));
+        }
+        // TODO: filtering train for translations
+        if (outputTranslations == 0)
+        {
+            trainSentences = filterSentencesWithoutFeature(trainSentences, outputAnnotationName, outputVocabulary);
+            if (removeDuplicateSentences)
+            {
+                trainSentences = removeDuplicates(trainSentences);
+            }
+        }
+        removeEmptyParallelSentences(trainSentences, translationTrainSentences);
+        removeEmptyParallelSentences(devSentences, translationDevSentences);
 
         for (int i = 0 ; i < inputFeatures ; i++)
         {
-            writeVocabulary(inputVocabulary.get(i), outputDirectoryPath + inputVocabularyFileName + i);
+            writeVocabulary(inputVocabulary.get(i), outputDirectoryPath + inputVocabularyFileName + i, inputClearText.get(i));
         }
         for (int i = 0 ; i < outputFeatures ; i++)
         {
-            writeVocabulary(outputVocabulary.get(i), outputDirectoryPath + outputVocabularyFileName + i);
+            writeVocabulary(outputVocabulary.get(i), outputDirectoryPath + outputVocabularyFileName + i, false);
+        }
+        for (int i = 0 ; i < outputTranslations ; i++)
+        {
+            for (int j = 0 ; j < outputTranslationFeatures ; j++)
+            {
+                writeVocabulary(outputTranslationVocabulary.get(i).get(j), outputDirectoryPath + outputTranslationVocabularyFileName1 + i + outputTranslationVocabularyFileName2 + j, outputTranslationClearText);
+            }
         }
-        writeCorpus(trainSentences, outputDirectoryPath + trainFileName);
-        List<Sentence> devSentences = preprocessCorpora(originalDevPaths, true);
-        devSentences.addAll(additionalDevSentences);
-        writeCorpus(devSentences, outputDirectoryPath + devFileName);
+        writeCorpus(trainSentences, translationTrainSentences, outputDirectoryPath + trainFileName);
+        writeCorpus(devSentences, translationDevSentences, outputDirectoryPath + devFileName);
         writeConfigFile(outputDirectoryPath + configFileName);
     }
 
-    private void initInputVocabulary()
+    private void initInputVocabulary() throws IOException
     {
+        assert(inputFeatures > 0);
         for (int i = 0 ; i < inputFeatures ; i++)
         {
            if (inputEmbeddingsPath.get(i) != null)
            {
                inputVocabulary.add(loadEmbeddings(inputEmbeddingsPath.get(i)));
            }
+           else if (inputVocabularyPath.get(i) != null)
+           {
+               inputVocabulary.add(loadVocabulary(inputVocabularyPath.get(i)));
+           }
            else
            {
                inputVocabulary.add(createNewInputVocabulary());
            }
-           inputVocabularyCurrentIndex.add(inputVocabulary.get(inputVocabulary.size() - 1).size());
         }
     }
@@ -192,15 +321,36 @@ private void initOutputVocabulary() throws IOException
     {
         for (int i = 0 ; i < outputFeatures ; i++)
         {
-            if (fixedOutputVocabularyPath.get(i) != null)
+            if (outputFixedVocabularyPath.get(i) != null)
             {
-                outputVocabulary.add(readVocabulary(fixedOutputVocabularyPath.get(i)));
+                outputVocabulary.add(readVocabulary(outputFixedVocabularyPath.get(i)));
             }
             else
             {
                 outputVocabulary.add(createNewOutputVocabulary());
             }
-            currentOutputVocabularyIndex.add(outputVocabulary.get(outputVocabulary.size() - 1).size());
+        }
+    }
+
+    private void initOutputTranslationVocabulary() throws IOException
+    {
+        for (int i = 0 ; i < outputTranslations ; i++)
+        {
+            List<Map<String, Integer>> translationVocabulary = new ArrayList<>();
+            assert(!outputTranslationAnnotationName.isEmpty());
+            assert(outputTranslationAnnotationName.size() == outputTranslationFeatures);
+            for (int j = 0 ; j < outputTranslationFeatures ; j++)
+            {
+                if (outputTranslationFixedVocabularyPath.get(i) != null)
+                {
+                    translationVocabulary.add(loadVocabulary(outputTranslationFixedVocabularyPath.get(i)));
+                }
+                else
+                {
+                    translationVocabulary.add(createNewOutputTranslationVocabulary());
+                }
+            }
+            outputTranslationVocabulary.add(translationVocabulary);
         }
     }
 
@@ -208,7 +358,7 @@ private Map<String, Integer> loadEmbeddings(String embeddingsPath)
     {
         WordVectors embeddings = new TextualModelLoader(false).loadVocabularyOnly(embeddingsPath);
         Map<String, Integer> vocabulary = createNewInputVocabulary();
-        int i = 2;
+        int i = vocabulary.size();
         for (String vocab : embeddings.getVocabulary())
         {
             vocabulary.put(vocab, i);
@@ -217,6 +367,20 @@ private Map<String, Integer> loadEmbeddings(String embeddingsPath)
         return vocabulary;
     }
 
+    private Map<String, Integer> loadVocabulary(String vocabularyPath) throws IOException
+    {
+        Map<String, Integer> vocabulary = createNewInputVocabulary();
+        Wrapper<Integer> i = new Wrapper<>(vocabulary.size());
+        BufferedReader in = Files.newBufferedReader(Paths.get(vocabularyPath));
+        in.lines().forEach(line ->
+        {
+            vocabulary.put(line, i.obj);
+            i.obj++;
+        });
+        in.close();
+        return vocabulary;
+    }
+
     private Map<String, Integer> readVocabulary(String vocabularyPath) throws IOException
     {
         Map<String, Integer> vocabulary = new HashMap<>();
@@ -235,6 +399,8 @@ private Map<String, Integer> createNewInputVocabulary()
         Map<String, Integer> vocabulary = new HashMap<>();
         vocabulary.put(paddingToken, 0);
         vocabulary.put(unknownToken, 1);
+        vocabulary.put(beginningOfSentenceToken, 2);
+        vocabulary.put(endOfSentenceToken, 3);
         return vocabulary;
     }
 
@@ -245,12 +411,25 @@ private Map<String, Integer> createNewOutputVocabulary()
         return vocabulary;
     }
 
-    private void writeVocabulary(Map<String, Integer> vocabulary, String vocabularyPath) throws IOException
+    private Map<String, Integer> createNewOutputTranslationVocabulary()
+    {
+        Map<String, Integer> vocabulary = new HashMap<>();
+        vocabulary.put(paddingToken, 0);
+        vocabulary.put(unknownToken, 1);
+        vocabulary.put(beginningOfSentenceToken, 2);
+        vocabulary.put(endOfSentenceToken, 3);
+        return vocabulary;
+    }
+
+    private void writeVocabulary(Map<String, Integer> vocabulary, String vocabularyPath, boolean clearText) throws IOException
     {
         BufferedWriter out = Files.newBufferedWriter(Paths.get(vocabularyPath));
-        for (Map.Entry<String, Integer> vocab : vocabulary.entrySet().stream().sorted(Map.Entry.comparingByValue()).collect(Collectors.toList()))
+        if (!clearText)
         {
-            out.write("" + vocab.getValue() + " " + vocab.getKey() + "\n");
+            for (Map.Entry<String, Integer> vocab : vocabulary.entrySet().stream().sorted(Map.Entry.comparingByValue()).collect(Collectors.toList()))
+            {
+                out.write("" + vocab.getKey() + "\n");
+            }
         }
         out.close();
     }
@@ -262,169 +441,385 @@ private void writeConfigFile(String configFilePath) throws IOException
         config.put("input_features", inputFeatures);
         config.put("input_annotation_name", inputAnnotationName);
         config.put("input_embeddings_path", inputEmbeddingsPath.stream().map(p -> p != null ? Paths.get(p).toAbsolutePath().toString() : null).collect(Collectors.toList()));
-        config.put("input_embeddings_size", inputEmbeddingsPath.stream().map(p -> p != null ? null : 300).collect(Collectors.toList()));
+        config.put("input_clear_text", IntStream.range(0, inputFeatures).boxed().map(i -> inputClearText.get(i)).collect(Collectors.toList()));
         config.put("output_features", outputFeatures);
         config.put("output_annotation_name", outputAnnotationName);
-        config.put("lstm_units_size", 1000);
-        config.put("lstm_layers", 1);
-        config.put("linear_before_lstm", false);
-        config.put("dropout_rate_before_lstm", null);
-        config.put("dropout_rate", 0.5);
-        config.put("word_dropout_rate", null);
-        config.put("attention_layer", false);
-        config.put("legacy_model", false);
+        config.put("output_translations", outputTranslations);
+        config.put("output_translation_name", outputTranslationName);
+        config.put("output_translation_features", outputTranslationFeatures);
+        config.put("output_translation_annotation_name", outputTranslationAnnotationName);
+        config.put("output_translation_clear_text", outputTranslationClearText);
         BufferedWriter out = Files.newBufferedWriter(Paths.get(configFilePath));
         Json.write(out, config);
         out.close();
     }
 
-    private List<Sentence> preprocessCorpora(List<String> originalCorpusPaths, boolean vocabularyIsFixed)
+    private List<Sentence> extractSentencesFromCorpora(List<String> originalCorpusPaths) throws Exception
+    {
+        List<Sentence> sentences;
+        if (corpusFormat.equals("xml"))
+        {
+            sentences = extractSentencesFromUFSACCorpora(originalCorpusPaths);
+        }
+        else
+        {
+            sentences = extractSentencesFromTXTCorpora(originalCorpusPaths);
+        }
+        cleanSentences(sentences);
+        return sentences;
+    }
+
+    private List<Sentence> extractSentencesFromTXTCorpora(List<String> originalCorpusPaths) throws Exception
+    {
+        List<Sentence> allSentences = new ArrayList<>();
+        if (txtCorpusFeatures.isEmpty())
+        {
+            txtCorpusFeatures = new ArrayList<>();
+            for (int i = 0; i < inputFeatures; i++)
+            {
+                txtCorpusFeatures.add(inputAnnotationName.get(i));
+            }
+            for (int i = 0; i < outputFeatures; i++)
+            {
+                txtCorpusFeatures.add(outputAnnotationName.get(i));
+            }
+        }
+        for (String originalCorpusPath : originalCorpusPaths)
+        {
+            System.out.println("Extracting sentences from corpus " + originalCorpusPath);
+            BufferedReader reader = Files.newBufferedReader(Paths.get(originalCorpusPath));
+            for (String line = reader.readLine(); line != null ; line = reader.readLine())
+            {
+                Sentence sentence = new Sentence();
+                String[] words = line.split(RegExp.anyWhiteSpaceGrouped.pattern());
+                for (String word : words)
+                {
+                    Word ufsacWord = new Word();
+                    String[] wordFeatures = word.split(Pattern.quote("|"));
+                    if (wordFeatures.length < 1)
+                    {
+                        System.out.println("Warning: empty word in sentence: " + line);
+                        wordFeatures = new String[]{"/"};
+                    }
+                    ufsacWord.setValue(wordFeatures[0]);
+                    for (int i = 1; i < txtCorpusFeatures.size(); i++)
+                    {
+                        if (wordFeatures.length > i)
+                        {
+                            ufsacWord.setAnnotation(txtCorpusFeatures.get(i), wordFeatures[i]);
+                        }
+                    }
+                    sentence.addWord(ufsacWord);
+                }
+                allSentences.add(sentence);
+            }
+        }
+        return allSentences;
+    }
+
+    private List<Sentence> extractSentencesFromUFSACCorpora(List<String> originalCorpusPaths)
     {
         List<Sentence> allSentences = new ArrayList<>();
         StreamingCorpusReaderSentence reader = new StreamingCorpusReaderSentence()
         {
             @Override
-            public void readSentence(Sentence s)
+            public void readSentence(getalp.wsd.ufsac.core.Sentence sentence)
             {
-                List<Word> words = s.getWords();
allSentences.add(new Sentence(sentence)); + } + }; - /// truncate lines too long - if (words.size() > maxLineLength) - { - words = words.subList(0, maxLineLength); - } + for (String originalCorpusPath : originalCorpusPaths) + { + System.out.println("Extracting sentences from corpus " + originalCorpusPath); + reader.load(originalCorpusPath); + } + return allSentences; + } - /// filtering out sentences with no output features - boolean sentenceHasOutputFeatures = false; - for (Word w : words) + private void cleanSentences(List sentences) + { + System.out.println("Cleaning sentences"); + for (Sentence s : sentences) + { + s.limitSentenceLength(maxLineLength); + + List words = s.getWords(); + + for (Word w : words) + { + /// add monosemics if asked + if (addMonosemics && !w.hasAnnotation(senseTag) && w.hasAnnotation("lemma") && w.hasAnnotation("pos")) { - /// add monosemics if asked - if (addMonosemics && !w.hasAnnotation(senseTag) && w.hasAnnotation("lemma") && w.hasAnnotation("pos")) + String wordKey = w.getAnnotationValue("lemma") + "%" + POSConverter.toWNPOS(w.getAnnotationValue("pos")); + if (wn.isWordKeyExists(wordKey)) { - String wordKey = w.getAnnotationValue("lemma") + "%" + POSConverter.toWNPOS(w.getAnnotationValue("pos")); - if (wn.isWordKeyExists(wordKey)) + List senseKeys = wn.getSenseKeyListFromWordKey(wordKey); + if (senseKeys.size() == 1) { - List senseKeys = wn.getSenseKeyListFromWordKey(wordKey); - if (senseKeys.size() == 1) - { - w.setAnnotation(senseTag, senseKeys.get(0)); - } + w.setAnnotation(senseTag, senseKeys.get(0)); } } + } + + /// clean output sense tags, convert them to synset keys + if (w.hasAnnotation(senseTag)) + { + List senseKeys = w.getAnnotationValues(senseTag, ";"); + + if (!w.hasAnnotation("lemma")) + { + w.setAnnotation("lemma", WordnetUtils.extractLemmaFromSenseKey(senseKeys.get(0))); + } + if (!w.hasAnnotation("pos")) + { + w.setAnnotation("pos", WordnetUtils.extractPOSFromSenseKey(senseKeys.get(0))); + } + + String wordKey = w.getAnnotationValue("lemma") + "%" + POSConverter.toWNPOS(w.getAnnotationValue("pos")); - /// clean output sense tags, convert them to synset keys - if (w.hasAnnotation(senseTag)) + if (addWordKeyFromSenseKey) { - String wordKey = w.getAnnotationValue("lemma") + "%" + POSConverter.toWNPOS(w.getAnnotationValue("pos")); + w.setAnnotation("word_key", wordKey); + } - if (removeMonosemics && wn.getSenseKeyListFromWordKey(wordKey).size() == 1) - { - w.removeAnnotation(senseTag); - } + if (removeMonosemics && wn.getSenseKeyListFromWordKey(wordKey).size() == 1) + { + w.removeAnnotation(senseTag); + senseKeys = Collections.emptyList(); + } - List senseKeys = w.getAnnotationValues(senseTag, ";"); - Set synsetKeys = WordnetUtils.getUniqueSynsetKeysFromSenseKeys(wn, senseKeys); - List finalSynsetKeys = new ArrayList<>(); + Set synsetKeys = WordnetUtils.getUniqueSynsetKeysFromSenseKeys(wn, senseKeys); + List finalSynsetKeys = new ArrayList<>(); - if (removeAllCoarseGrained && synsetKeys.size() > 1) + if (removeAllCoarseGrained && synsetKeys.size() > 1) + { + synsetKeys.clear(); + } + for (String synsetKey : synsetKeys) + { + if (reducedOutputVocabulary != null) { - synsetKeys.clear(); + synsetKey = reducedOutputVocabulary.getOrDefault(synsetKey, synsetKey); } - for (String synsetKey : synsetKeys) + finalSynsetKeys.add(synsetKey); + } + if (finalSynsetKeys.isEmpty()) + { + w.removeAnnotation(senseTag); + } + else + { + if (!multisenses) { + finalSynsetKeys = finalSynsetKeys.subList(0, 1); + } + w.setAnnotation(senseTag, finalSynsetKeys, ";"); + 
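The block that follows is the heart of the corpus cleaning step: gold sense keys are converted to synset keys, optionally remapped through `reducedOutputVocabulary` (the hypernymy-based sense vocabulary compression), and written back as a `;`-joined annotation. Below is a self-contained sketch of the remapping logic (editor's illustration, not part of the patch), with plain `HashMap`s standing in for the WordNet helper and the compression table; all keys are invented for illustration:

```java
import java.util.HashMap;
import java.util.Map;

public class SenseCompressionDemo
{
    public static void main(String[] args)
    {
        // hypothetical sense-key -> synset-key table
        // (stands in for wn.getSynsetKeyFromSenseKey)
        Map<String, String> senseToSynset = new HashMap<>();
        senseToSynset.put("mouse%1:05:00::", "02330245n");
        senseToSynset.put("mouse%1:06:00::", "03793489n");

        // hypothetical compression clusters (stands in for reducedOutputVocabulary):
        // a synset key maps to its cluster representative, or to itself if absent
        Map<String, String> reduced = new HashMap<>();
        reduced.put("02330245n", "00015388n");

        for (String senseKey : senseToSynset.keySet())
        {
            String synsetKey = senseToSynset.get(senseKey);
            // same getOrDefault pattern as in the patch
            synsetKey = reduced.getOrDefault(synsetKey, synsetKey);
            System.out.println(senseKey + " -> " + synsetKey);
        }
    }
}
```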
} + } + + // lowercase word + if (lowercaseWords) + { + w.setValue(w.getValue().toLowerCase()); + } + + // uniformize dash + if (uniformDash) + { + w.setValue(w.getValue().replaceAll("_", "-")); + } + } + } + } + + private void buildExtraWordKeysVocabulary(List sentences, boolean isTrain) + { + if (extraWordKeys == null) return; + System.out.println("Building extra wordkeys vocabulary"); + Map> sensesPerWordKey = new HashMap<>(); + for (Sentence s : sentences) + { + List words = s.getWords(); + for (Word w : words) + { + if (w.hasAnnotation(senseTag)) + { + String senseKey = w.getAnnotationValue(senseTag); + for (String extraWordKey : extraWordKeys) + { + // if extraWordKey could have the sense senseKey, then we set this annotation + boolean isPossibleSense = false; + for (String extrasenseKey : wn.getSenseKeyListFromWordKey(extraWordKey)) + { + String extraSynsetKey = wn.getSynsetKeyFromSenseKey(extrasenseKey); if (reducedOutputVocabulary != null) { - synsetKey = reducedOutputVocabulary.getOrDefault(synsetKey, synsetKey); + extraSynsetKey = reducedOutputVocabulary.getOrDefault(extraSynsetKey, extraSynsetKey); - finalSynsetKeys.add(synsetKey); - } - if (finalSynsetKeys.isEmpty()) - { - w.removeAnnotation(senseTag); - } - else - { - if (!multisenses) + if (extraSynsetKey.equals(senseKey)) { - finalSynsetKeys = finalSynsetKeys.subList(0, 1); + isPossibleSense = true; + break; } - w.setAnnotation(senseTag, finalSynsetKeys, ";"); } - } - - /// check if any word contains any output feature - /// and construct output vocabulary - for (int i = 0 ; i < outputFeatures ; i++) - { - Map featureVocabulary = outputVocabulary.get(i); - List featureValues = w.getAnnotationValues(outputAnnotationName.get(i), ";"); - if (!featureValues.isEmpty()) + if (isPossibleSense) { - for (String featureValue : featureValues) + w.setAnnotation(extraWordKey, senseKey); + if (isTrain) { - if (featureVocabulary.containsKey(featureValue)) - { - sentenceHasOutputFeatures = true; - } - else if (!vocabularyIsFixed && fixedOutputVocabularyPath.get(i) == null) - { - int currentIndex = currentOutputVocabularyIndex.get(i); - featureVocabulary.put(featureValue, currentIndex); - currentOutputVocabularyIndex.set(i, currentIndex + 1); - sentenceHasOutputFeatures = true; - } + sensesPerWordKey.putIfAbsent(extraWordKey, new HashSet<>()); + sensesPerWordKey.get(extraWordKey).add(senseKey); } } } } + } + } + if (isTrain) + { + for (String extraWordKey : sensesPerWordKey.keySet()) + { + if (sensesPerWordKey.get(extraWordKey).size() > 1) + { + addOutputFeature(extraWordKey, null); + } + } + } + } - /// skip this sentence - if (!sentenceHasOutputFeatures) return; + private void buildVocabulary(List allSentences, + List annotationName, List fixedVocabularyPath, + List> vocabulary, + boolean isInputVocabulary, int vocabularyLimit) + { + System.out.println("Building vocabulary"); - for (Word w : words) + List> vocabularyFrequencies = new ArrayList<>(); + for (int i = 0 ; i < annotationName.size(); i++) + { + vocabularyFrequencies.add(new HashMap<>()); + } + + for (Sentence s : allSentences) + { + List words = s.getWords(); + + for (Word w : words) + { + for (int i = 0 ; i < annotationName.size() ; i++) { - // lowercase word - if (lowercaseWords) + if (fixedVocabularyPath.get(i) != null) continue; + List featureValues; + if (isInputVocabulary) { - w.setValue(w.getValue().toLowerCase()); + featureValues = Collections.singletonList(w.getAnnotationValue(annotationName.get(i))); + if (featureValues.get(0).isEmpty()) continue; } - - // 
uniformize dash - if (uniformDash) + else { - w.setValue(w.getValue().replaceAll("_", "-")); + featureValues = w.getAnnotationValues(annotationName.get(i), ";"); + if (featureValues.isEmpty()) continue; } + for (String featureValue : featureValues) + { + Map featureFrequencies = vocabularyFrequencies.get(i); + int currentFrequency = featureFrequencies.getOrDefault(featureValue, 0); + currentFrequency += 1; + featureFrequencies.put(featureValue, currentFrequency); + } + } + } + } + + for (int i = 0 ; i < annotationName.size() ; i++) + { + if (fixedVocabularyPath.get(i) != null) continue; + Map featureFrequencies = vocabularyFrequencies.get(i); + Map featureVocabulary; + if (isInputVocabulary) + { + featureVocabulary = createNewInputVocabulary(); + } + else + { + featureVocabulary = createNewOutputVocabulary(); + } + int initVocabularySize = featureVocabulary.size(); + List sortedKeys = featureFrequencies.entrySet().stream().sorted(Map.Entry.comparingByValue()).map(Map.Entry::getKey).collect(Collectors.toList()); + Collections.reverse(sortedKeys); + if (vocabularyLimit <= 0) + { + vocabularyLimit = sortedKeys.size(); + } + else + { + vocabularyLimit = Math.min(vocabularyLimit, sortedKeys.size()); + } + for (int j = 0; j < vocabularyLimit; j++) + { + featureVocabulary.put(sortedKeys.get(j), j + initVocabularySize); + } + vocabulary.set(i, featureVocabulary); + } + } - // construct input vocabulary - for (int i = 0 ; i < inputFeatures ; i++) + private void mergeVocabularies(Map vocabulary1, Map vocabulary2) + { + int j = vocabulary1.size(); + for (String wordInVocabulary2 : vocabulary2.keySet()) + { + if (!vocabulary1.containsKey(wordInVocabulary2)) + { + vocabulary1.put(wordInVocabulary2, j); + j++; + } + } + vocabulary2.putAll(vocabulary1); + } + + private List filterSentencesWithoutFeature(List allSentences, List annotationName, + List> vocabulary) + { + System.out.println("Filtering sentences without feature"); + + List filteredSentences = new ArrayList<>(); + + for (Sentence s : allSentences) + { + List words = s.getWords(); + boolean sentenceHasOutputFeatures = false; + for (Word w : words) + { + for (int i = 0 ; i < annotationName.size() ; i++) + { + List featureValues = w.getAnnotationValues(annotationName.get(i), ";"); + if (featureValues.isEmpty()) continue; + Map featureVocabulary = vocabulary.get(i); + for (String featureValue : featureValues) { - String featureValue = w.getAnnotationValue(inputAnnotationName.get(i)); - Map featureVocabulary = inputVocabulary.get(i); - if (!featureValue.isEmpty() && !featureVocabulary.containsKey(featureValue) && !vocabularyIsFixed && inputEmbeddingsPath.get(i) == null) + if (featureVocabulary.containsKey(featureValue)) { - int currentIndex = inputVocabularyCurrentIndex.get(i); - featureVocabulary.put(featureValue, currentIndex); - inputVocabularyCurrentIndex.set(i, currentIndex + 1); + sentenceHasOutputFeatures = true; } } } - - /// add the sentence - allSentences.add(new Sentence(new ArrayList<>(words))); } - }; - for (String originalCorpusPath : originalCorpusPaths) - { - reader.load(originalCorpusPath); + if (sentenceHasOutputFeatures) + { + filteredSentences.add(s); + } } - return allSentences; + return filteredSentences; } private List removeDuplicates(List sentences) { + System.out.println("Removing duplicate sentences"); + Map realSentences = new HashMap<>(); for (Sentence currentSentence : sentences) { @@ -496,22 +891,107 @@ else if (!multisenses) return realTrueSentences; } - private List extractDevSentences(List trainSentences, int 
count) + private void extractDevSentencesParallel(List trainSentences, List> translatedTrainSentences, + List devSentences, List> translatedDevSentences, + int count) + { + if (count <= 0) return; + if (trainSentences.size() <= count) return; + System.out.println("Extracting dev sentences from train"); + + // generating random indices of sentences to transfer + List randomIndices = new ArrayList<>(); + for (int i = 0 ; i < trainSentences.size() ; i++) + { + randomIndices.add(i); + } + Collections.shuffle(randomIndices); + randomIndices = randomIndices.subList(0, count); + + // fetching Sentence object from indices + List trainSentencesToExtract = new ArrayList<>(); + List> translatedTrainSentencesToExtract = new ArrayList<>(); + for (int i = 0 ; i < outputTranslations ; i++) + { + translatedTrainSentencesToExtract.add(new ArrayList<>()); + } + for (int index : randomIndices) + { + trainSentencesToExtract.add(trainSentences.get(index)); + for (int i = 0 ; i < outputTranslations ; i++) + { + translatedTrainSentencesToExtract.get(i).add(translatedTrainSentences.get(i).get(index)); + } + } + + // actual remove from train / add to dev the sentences + trainSentences.removeAll(trainSentencesToExtract); + devSentences.addAll(trainSentencesToExtract); + for (int i = 0 ; i < outputTranslations ; i++) + { + translatedTrainSentences.get(i).removeAll(translatedTrainSentencesToExtract.get(i)); + translatedDevSentences.get(i).addAll(translatedTrainSentencesToExtract.get(i)); + } + } + + private void removeEmptyParallelSentences(List sentences, List> translatedSentences) { - if (count <= 0) return Collections.emptyList(); - Collections.shuffle(trainSentences); - List subPartOfTrainSentences = trainSentences.subList(0, count); - List devSentences = new ArrayList<>(subPartOfTrainSentences); - subPartOfTrainSentences.clear(); - return devSentences; + System.out.println("Removing empty parallel sentences"); + List sentenceIndicesToRemove = new ArrayList<>(); + for (int i = 0 ; i < sentences.size() ; i++) + { + for (List oneLanguageTranslatedSentences : translatedSentences) + { + Sentence translatedSentence = oneLanguageTranslatedSentences.get(i); + List translatedSentenceWords = translatedSentence.getWords(); + boolean empty = true; + for (Word w : translatedSentenceWords) + { + if (!empty) break; + for (String annotation : outputTranslationAnnotationName) + { + if (w.hasAnnotation(annotation)) + { + empty = false; + break; + } + } + } + if (empty) + { + sentenceIndicesToRemove.add(i); + break; + } + } + } + sentences.removeAll(sentenceIndicesToRemove.stream().map(sentences::get).collect(Collectors.toList())); + for (List tsentences : translatedSentences) + { + tsentences.removeAll(sentenceIndicesToRemove.stream().map(tsentences::get).collect(Collectors.toList())); + } } - private void writeCorpus(List sentences, String corpusPath) throws Exception + private String getTranslationCorpusName(String originalCorpusName, String translationName) { + if (corpusFormat.equals("xml")) + { + return originalCorpusName.substring(0, originalCorpusName.lastIndexOf(".xml")) + "." + translationName + ".xml"; + } + else + { + return originalCorpusName.substring(0, originalCorpusName.lastIndexOf(".")) + "." 
+ translationName; + } + } + + private void writeCorpus(List sentences, List> translatedSentences, String corpusPath) throws Exception + { + System.out.println("Writing corpus " + corpusPath); + Wrapper writer = new Wrapper<>(); writer.obj = Files.newBufferedWriter(Paths.get(corpusPath)); - for (Sentence s : sentences) + for (int si = 0 ; si < sentences.size(); si++) { + Sentence s = sentences.get(si); List words = s.getWords(); /// step 1 : write input features @@ -522,72 +1002,122 @@ private void writeCorpus(List sentences, String corpusPath) throws Exc { String featureValue = w.getAnnotationValue(inputAnnotationName.get(i)); Map featureVocabulary = inputVocabulary.get(i); - if (featureValue.isEmpty() || !featureVocabulary.containsKey(featureValue)) + if (featureValue.isEmpty() || !(featureVocabulary.containsKey(featureValue) || (inputVocabularyLimit <= 0 && inputClearText.get(i)))) { featureValue = unknownToken; - } - featureValue = Integer.toString(featureVocabulary.get(featureValue)); + } + if (inputClearText.get(i)) + { + featureValue = featureValue.replace("/", ""); + } + else + { + featureValue = Integer.toString(featureVocabulary.get(featureValue)); + } featureValues.add(featureValue); } writer.obj.write(StringUtils.join(featureValues, "/") + " "); } writer.obj.newLine(); - /// step 3 : write output features - for (Word w : words) + if (outputFeatures > 0) { - List featureValues = new ArrayList<>(); - for (int i = 0 ; i < outputFeatures ; i++) + /// step 2 : write output features + for (Word w : words) { - Map featureVocabulary = outputVocabulary.get(i); - List thisFeatureValues = w.getAnnotationValues(outputAnnotationName.get(i), ";"); - thisFeatureValues = thisFeatureValues.stream().filter(featureVocabulary::containsKey).collect(Collectors.toList()); - if (thisFeatureValues.isEmpty()) + List featureValues = new ArrayList<>(); + for (int i = 0; i < outputFeatures; i++) { - thisFeatureValues = Collections.singletonList(skipToken); + Map featureVocabulary = outputVocabulary.get(i); + List thisFeatureValues = w.getAnnotationValues(outputAnnotationName.get(i), ";"); + thisFeatureValues = thisFeatureValues.stream().filter(featureVocabulary::containsKey).collect(Collectors.toList()); + if (thisFeatureValues.isEmpty()) + { + thisFeatureValues = Collections.singletonList(skipToken); + } + thisFeatureValues = thisFeatureValues.stream().map(value -> Integer.toString(featureVocabulary.get(value))).collect(Collectors.toList()); + featureValues.add(StringUtils.join(thisFeatureValues, ";")); } - thisFeatureValues = thisFeatureValues.stream().map(value -> Integer.toString(featureVocabulary.get(value))).collect(Collectors.toList()); - featureValues.add(StringUtils.join(thisFeatureValues, ";")); + writer.obj.write(StringUtils.join(featureValues, "/") + " "); } - writer.obj.write(StringUtils.join(featureValues, "/") + " "); - } - writer.obj.newLine(); + writer.obj.newLine(); - /// step 4 : write output (sense_tag) restrictions - for (Word w : words) - { - for (int i = 0 ; i < outputFeatures ; i++) + /// step 3 : write output (sense_tag) restrictions + for (Word w : words) { - if (i > 0) - { - writer.obj.write("/"); - } - if (outputFeatureSenseIndex == i && w.hasAnnotation(senseTag) && w.getAnnotationValues(senseTag, ";").stream().anyMatch(outputVocabulary.get(i)::containsKey)) + for (int i = 0 ; i < outputFeatures ; i++) { - List restrictedSenses = new ArrayList<>(); - String wordKey = w.getAnnotationValue("lemma") + "%" + POSConverter.toWNPOS(w.getAnnotationValue("pos")); - for (String senseKey : 
wn.getSenseKeyListFromWordKey(wordKey)) + if (i > 0) { - String synsetKey = wn.getSynsetKeyFromSenseKey(senseKey); - if (reducedOutputVocabulary != null) + writer.obj.write("/"); + } + String featureTag = outputAnnotationName.get(i); + Map featureVocabulary = outputVocabulary.get(i); + if (w.hasAnnotation(featureTag) && w.getAnnotationValues(featureTag, ";").stream().anyMatch(featureVocabulary::containsKey)) + { + if (outputFeatureSenseIndex == i) { - synsetKey = reducedOutputVocabulary.getOrDefault(synsetKey, synsetKey); + List restrictedSenses = new ArrayList<>(); + String wordKey = w.getAnnotationValue("lemma") + "%" + POSConverter.toWNPOS(w.getAnnotationValue("pos")); + for (String senseKey : wn.getSenseKeyListFromWordKey(wordKey)) + { + String synsetKey = wn.getSynsetKeyFromSenseKey(senseKey); + if (reducedOutputVocabulary != null) + { + synsetKey = reducedOutputVocabulary.getOrDefault(synsetKey, synsetKey); + } + if (featureVocabulary.containsKey(synsetKey)) + { + restrictedSenses.add("" + featureVocabulary.get(synsetKey)); + } + } + writer.obj.write(StringUtils.join(restrictedSenses, ";")); } - if (outputVocabulary.get(i).containsKey(synsetKey)) + else { - restrictedSenses.add("" + outputVocabulary.get(outputFeatureSenseIndex).get(synsetKey)); + writer.obj.write("-1"); } } - writer.obj.write(StringUtils.join(restrictedSenses, ";")); + else + { + writer.obj.write("0"); + } } - else + writer.obj.write(" "); + } + writer.obj.newLine(); + } + + /// step 4 : write translation output + for (int ti = 0 ; ti < outputTranslations ; ti++) + { + Sentence translatedSentence = translatedSentences.get(ti).get(si); + words = translatedSentence.getWords(); + for (Word w : words) + { + List featureValues = new ArrayList<>(); + for (int i = 0; i < outputTranslationFeatures; i++) { - writer.obj.write("0"); + String featureValue = w.getAnnotationValue(outputTranslationAnnotationName.get(i)); + Map featureVocabulary = outputTranslationVocabulary.get(ti).get(i); + if (featureValue.isEmpty() || !featureVocabulary.containsKey(featureValue)) + { + featureValue = unknownToken; + } + if (outputTranslationClearText) + { + featureValue = featureValue.replace("/", ""); + } + else + { + featureValue = Integer.toString(featureVocabulary.get(featureValue)); + } + featureValues.add(featureValue); } + writer.obj.write(StringUtils.join(featureValues, "/") + " "); } - writer.obj.write(" "); + writer.obj.newLine(); } - writer.obj.newLine(); } writer.obj.close(); } diff --git a/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java b/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java index f2ea5d2..428d651 100644 --- a/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java +++ b/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java @@ -9,42 +9,67 @@ import getalp.wsd.common.utils.RegExp; import getalp.wsd.common.utils.StringUtils; import getalp.wsd.common.wordnet.WordnetHelper; -import getalp.wsd.method.DisambiguatorContextSentence; +import getalp.wsd.method.DisambiguatorContextSentenceBatch; +import getalp.wsd.ufsac.core.Sentence; import getalp.wsd.ufsac.core.Word; import getalp.wsd.utils.Json; -public class NeuralDisambiguator extends DisambiguatorContextSentence implements AutoCloseable +public class NeuralDisambiguator extends DisambiguatorContextSentenceBatch implements AutoCloseable { private static final String unknownToken = ""; - + private WordnetHelper wn = WordnetHelper.wn30(); private int inputFeatures; - + private List inputAnnotationNames; - + + 
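The fields declared here set up the core architecture of `NeuralDisambiguator`: the Java side encodes sentences, writes them line by line to the standard input of a Python prediction process, and reads one prediction line back per sentence. A minimal sketch of that line-oriented plumbing (editor's illustration, not part of the patch; `cat` stands in for the real `launch.sh` predictor, and a Unix-like system is assumed):

```java
import java.io.*;

public class LinePipeDemo
{
    public static void main(String[] args) throws Exception
    {
        // 'cat' echoes its input, standing in for the Python predictor
        ProcessBuilder pb = new ProcessBuilder("cat");
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        Process process = pb.start();
        BufferedWriter toProcess =
                new BufferedWriter(new OutputStreamWriter(process.getOutputStream()));
        BufferedReader fromProcess =
                new BufferedReader(new InputStreamReader(process.getInputStream()));

        toProcess.write("12/345 67/89"); // one sentence: '/'-joined feature indices per word
        toProcess.newLine();
        toProcess.flush();               // flush before blocking on the reply

        System.out.println("received: " + fromProcess.readLine());
        toProcess.close();
        process.waitFor();
    }
}
```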
private List inputClearText; + private List> inputVocabulary; - + private int outputFeatures; - // private List outputAnnotationNames; - + private int senseFeatureIndex; + + private int outputTranslations; + + private List outputAnnotationNames; + + private Map reversedOutputAnnotationNames; + private List> outputVocabulary; - + private List> reversedOutputVocabulary; - private Process pythonProcess = null; - - private BufferedReader pythonProcessReader = null; - - private BufferedWriter pythonProcessWriter = null; + private List> reversedOutputTranslationVocabulary; + + private Process pythonProcess; + + private BufferedReader pythonProcessReader; + + private BufferedWriter pythonProcessWriter; + + private boolean clearText; + + private int batchSize; + + private int beamSize; + + private boolean disambiguate = true; + + private boolean translate = false; + + private boolean extraLemma = false; // --- begin public options public boolean lowercaseWords = true; - + + public boolean filterLemma = true; + public Map reducedOutputVocabulary = null; - + // --- end public options public NeuralDisambiguator(String pythonPath, String neuralPath, String weightsPath) @@ -52,14 +77,61 @@ public NeuralDisambiguator(String pythonPath, String neuralPath, String weightsP this(pythonPath, neuralPath, Collections.singletonList(weightsPath)); } - public NeuralDisambiguator(String pythonPath, String neuralPath, List weightsPaths) + public NeuralDisambiguator(String pythonPath, String neuralPath, String weightsPath, boolean clearText, int batchSize) { + this(pythonPath, neuralPath, Collections.singletonList(weightsPath), clearText, batchSize); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, String weightsPath, boolean clearText, int batchSize, boolean extraLemma) + { + this(pythonPath, neuralPath, Collections.singletonList(weightsPath), clearText, batchSize, extraLemma); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, String weightsPath, boolean clearText, int batchSize, boolean translate, int beamSize) + { + this(pythonPath, neuralPath, Collections.singletonList(weightsPath), clearText, batchSize, translate, beamSize); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, String weightsPath, boolean clearText, int batchSize, boolean translate, int beamSize, boolean extraLemma) + { + this(pythonPath, neuralPath, Collections.singletonList(weightsPath), clearText, batchSize, translate, beamSize, extraLemma); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, List weightsPath) + { + this(pythonPath, neuralPath, weightsPath, false, 1); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, List weightsPaths, boolean clearText, int batchSize) + { + this(pythonPath, neuralPath, weightsPaths, clearText, batchSize, false, 1); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, List weightsPaths, boolean clearText, int batchSize, boolean extraLemma) + { + this(pythonPath, neuralPath, weightsPaths, clearText, batchSize, false, 1, extraLemma); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, List weightsPaths, boolean clearText, int batchSize, boolean translate, int beamSize) + { + this(pythonPath, neuralPath, weightsPaths, clearText, batchSize, translate, beamSize, false); + } + + public NeuralDisambiguator(String pythonPath, String neuralPath, List weightsPaths, boolean clearText, int batchSize, boolean translate, int beamSize, boolean extraLemma) + { + 
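For reference, a hypothetical call site for the constructor overloads defined below (editor's sketch, not part of the patch; paths and sizes are made up, and the project classes are assumed to be on the classpath). Since `NeuralDisambiguator` implements `AutoCloseable` and owns the Python subprocess, try-with-resources is the natural usage pattern:

```java
try (NeuralDisambiguator disambiguator = new NeuralDisambiguator(
        "python", "/path/to/data", "/path/to/model_weights_wsd0",
        false /* clearText */, 4 /* batchSize */))
{
    // feed batches of sentences to the disambiguator here
}
catch (Exception e)
{
    e.printStackTrace();
}
```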
super(batchSize); try { + this.clearText = clearText; + this.batchSize = batchSize; + this.beamSize = beamSize; + this.translate = translate; + this.extraLemma = extraLemma; initPythonProcess(pythonPath, neuralPath, weightsPaths); readConfigFile(neuralPath); initInputVocabulary(neuralPath); initOutputVocabulary(neuralPath); + initTranslationOutputVocabulary(neuralPath); } catch (Exception e) { @@ -67,27 +139,46 @@ public NeuralDisambiguator(String pythonPath, String neuralPath, List we } } + public int getInputFeatures() + { + return inputFeatures; + } + + public List getInputAnnotationNames() + { + return inputAnnotationNames; + } + private void initPythonProcess(String pythonPath, String neuralPath, List weightsPaths) throws IOException { List args = new ArrayList<>(Arrays.asList(pythonPath + "/launch.sh", "getalp.wsd.predict", "--data_path", neuralPath, "--weights")); args.addAll(weightsPaths); + if (clearText) args.add("--clear_text"); + args.addAll(Arrays.asList("--batch_size", "" + batchSize)); + if (disambiguate) args.add("--disambiguate"); + if (translate) args.add("--translate"); + args.addAll(Arrays.asList("--beam_size", "" + beamSize)); + if (extraLemma) args.add("--output_all_features"); ProcessBuilder pb = new ProcessBuilder(args); pb.redirectError(ProcessBuilder.Redirect.INHERIT); pythonProcess = pb.start(); pythonProcessReader = new BufferedReader(new InputStreamReader(pythonProcess.getInputStream())); pythonProcessWriter = new BufferedWriter(new OutputStreamWriter(pythonProcess.getOutputStream())); } - + @SuppressWarnings("unchecked") private void readConfigFile(String neuralPath) throws IOException { Map config = Json.readMap(neuralPath + "/config.json"); inputFeatures = (int) config.get("input_features"); inputAnnotationNames = (List) config.get("input_annotation_name"); + inputClearText = (List) config.get("input_clear_text"); outputFeatures = (int) config.get("output_features"); - // outputAnnotationNames = (List) config.get("output_annotation_name"); + outputAnnotationNames = (List) config.get("output_annotation_name"); + outputTranslations = (int) config.get("output_translation_features"); + senseFeatureIndex = 0; } - + private void initInputVocabulary(String neuralPath) throws Exception { inputVocabulary = new ArrayList<>(); @@ -99,38 +190,65 @@ private void initInputVocabulary(String neuralPath) throws Exception private void initOutputVocabulary(String neuralPath) throws Exception { - outputVocabulary = new ArrayList<>(); - reversedOutputVocabulary = new ArrayList<>(); + outputVocabulary = new ArrayList<>(); + reversedOutputVocabulary = new ArrayList<>(); + reversedOutputAnnotationNames = new HashMap<>(); for (int i = 0 ; i < outputFeatures ; i++) { - Map vocabulary = initVocabulary(neuralPath + "/output_vocabulary" + i); - Map reversedVocabulary = new HashMap<>(); - for (String key : vocabulary.keySet()) - { - reversedVocabulary.put(vocabulary.get(key), key); - } - outputVocabulary.add(vocabulary); - reversedOutputVocabulary.add(reversedVocabulary); + Map vocabulary = initVocabulary(neuralPath + "/output_vocabulary" + i); + Map reversedVocabulary = new HashMap<>(); + for (String key : vocabulary.keySet()) + { + reversedVocabulary.put(vocabulary.get(key), key); + } + outputVocabulary.add(vocabulary); + reversedOutputVocabulary.add(reversedVocabulary); + reversedOutputAnnotationNames.put(outputAnnotationNames.get(i), i); + } + } + + private void initTranslationOutputVocabulary(String neuralPath) throws Exception + { + reversedOutputTranslationVocabulary = new 
ArrayList<>(); + for (int i = 0 ; i < outputTranslations ; i++) + { + Map vocabulary = initVocabulary(neuralPath + "/output_translation" + i + "_vocabulary0"); + Map reversedVocabulary = new HashMap<>(); + for (String key : vocabulary.keySet()) + { + reversedVocabulary.put(vocabulary.get(key), key); + } + reversedOutputTranslationVocabulary.add(reversedVocabulary); } } - private HashMap initVocabulary(String filePath) throws Exception + private Map initVocabulary(String filePath) throws Exception { - HashMap ret = new HashMap<>(); + Map ret = new HashMap<>(); + List vocabAsList = new ArrayList<>(); BufferedReader reader = Files.newBufferedReader(Paths.get(filePath)); reader.lines().forEach(line -> { - String[] tokens = line.split(RegExp.anyWhiteSpaceGrouped.pattern()); - ret.put(tokens[1], Integer.valueOf(tokens[0])); + String[] linesplit = line.split(RegExp.anyWhiteSpaceGrouped.pattern()); + if (linesplit.length == 1) vocabAsList.add(linesplit[0]); + else vocabAsList.add(linesplit[1]); }); reader.close(); + for (int i = 0 ; i < vocabAsList.size() ; i++) + { + ret.put(vocabAsList.get(i), i); + } return ret; } - private void writePredictInput(List words) throws Exception + private void writePredictInput(List sentences) throws Exception { - writePredictInputSampleX(words); - writePredictInputSampleZ(words); + for (Sentence sentence : sentences) + { + List words = sentence.getWords(); + writePredictInputSampleX(words); + writePredictInputSampleZ(words); + } pythonProcessWriter.flush(); } @@ -143,15 +261,22 @@ private void writePredictInputSampleX(List words) throws Exception w.setValue(w.getValue().toLowerCase()); } List featureValues = new ArrayList<>(); - for (int i = 0 ; i < inputFeatures ; i++) + for (int i = 0; i < inputFeatures; i++) { String featureValue = w.getAnnotationValue(inputAnnotationNames.get(i)); - Map featureVocabulary = inputVocabulary.get(i); - if (featureValue.isEmpty() || !featureVocabulary.containsKey(featureValue)) + if (inputClearText.get(i) || clearText) { - featureValue = unknownToken; - } - featureValue = Integer.toString(featureVocabulary.get(featureValue)); + featureValue = featureValue.replace("/", ""); + } + else + { + Map featureVocabulary = inputVocabulary.get(i); + if (featureValue.isEmpty() || !featureVocabulary.containsKey(featureValue)) + { + featureValue = unknownToken; + } + featureValue = Integer.toString(featureVocabulary.get(featureValue)); + } featureValues.add(featureValue); } pythonProcessWriter.write(StringUtils.join(featureValues, "/") + " "); @@ -161,59 +286,105 @@ private void writePredictInputSampleX(List words) throws Exception private void writePredictInputSampleZ(List words) throws Exception { + if (outputFeatures <= 0) return; + if (extraLemma) return; for (Word word : words) { - List possibleSenseKeys = new ArrayList<>(); - - if (word.hasAnnotation("lemma") && word.hasAnnotation("pos")) + if (filterLemma) { - String pos = POSConverter.toWNPOS(word.getAnnotationValue("pos")); - List lemmas = word.getAnnotationValues("lemma", ";"); - for (String lemma : lemmas) + List possibleSenseKeys = new ArrayList<>(); + if (word.hasAnnotation("lemma") && word.hasAnnotation("pos")) { - String wordKey = lemma + "%" + pos; - if (!wn.isWordKeyExists(wordKey)) continue; - possibleSenseKeys.addAll(wn.getSenseKeyListFromWordKey(wordKey)); + String pos = POSConverter.toWNPOS(word.getAnnotationValue("pos")); + List lemmas = word.getAnnotationValues("lemma", ";"); + for (String lemma : lemmas) + { + String wordKey = lemma + "%" + pos; + if 
(!wn.isWordKeyExists(wordKey)) continue; + possibleSenseKeys.addAll(wn.getSenseKeyListFromWordKey(wordKey)); + } } - } - - if (!possibleSenseKeys.isEmpty()) - { - List possibleSenseKeyIndices = new ArrayList<>(); - for (String possibleSenseKey : possibleSenseKeys) + if (!possibleSenseKeys.isEmpty()) { - String possibleSynsetKey = wn.getSynsetKeyFromSenseKey(possibleSenseKey); - if (reducedOutputVocabulary != null) + List possibleSenseKeyIndices = new ArrayList<>(); + for (String possibleSenseKey : possibleSenseKeys) { - possibleSynsetKey = reducedOutputVocabulary.getOrDefault(possibleSynsetKey, possibleSynsetKey); + String possibleSynsetKey = wn.getSynsetKeyFromSenseKey(possibleSenseKey); + if (reducedOutputVocabulary != null) + { + possibleSynsetKey = reducedOutputVocabulary.getOrDefault(possibleSynsetKey, possibleSynsetKey); + } + if (outputVocabulary.get(senseFeatureIndex).containsKey(possibleSynsetKey)) + { + possibleSenseKeyIndices.add("" + outputVocabulary.get(senseFeatureIndex).get(possibleSynsetKey)); + } } - if (outputVocabulary.get(0).containsKey(possibleSynsetKey)) + if (possibleSenseKeyIndices.isEmpty()) { - possibleSenseKeyIndices.add("" + outputVocabulary.get(0).get(possibleSynsetKey)); + pythonProcessWriter.write("0 "); + } + else + { + pythonProcessWriter.write(StringUtils.join(possibleSenseKeyIndices, ";") + " "); } - } - if (possibleSenseKeyIndices.isEmpty()) - { - pythonProcessWriter.write("0 "); } else { - pythonProcessWriter.write(StringUtils.join(possibleSenseKeyIndices, ";") + " "); + pythonProcessWriter.write("0 "); } } else { - pythonProcessWriter.write("0 "); + pythonProcessWriter.write("-1 "); } } pythonProcessWriter.newLine(); } - - private void readPredictOutput(List words, String senseTag, String confidenceTag) throws Exception + + private List readPredictOutput(List sentences, String senseTag, String confidenceTag) throws Exception { - String line = pythonProcessReader.readLine(); - int[] output = parsePredictOutput(line); - propagatePredictOutput(words, output, senseTag, confidenceTag); + List translations = new ArrayList<>(); + for (int i = 0 ; i < sentences.size() ; i++) + { + Sentence sentence = sentences.get(i); + if (outputFeatures > 0) + { + List words = sentence.getWords(); + String line = pythonProcessReader.readLine(); + if (line.startsWith("Better speed can be achieved with apex installed")) + { + i--; + continue; + } + if (extraLemma) + { + int[][] output = parsePredictOutputExtraLemma(line); + propagatePredictOutputExtraLemma(words, output, senseTag, confidenceTag); + } + else + { + int[] output = parsePredictOutput(line); + propagatePredictOutput(words, output, senseTag, confidenceTag); + } + } + // TODO: for (int i = 0 ; i < outputTranslations ; i++) + if (outputTranslations > 0) + { + String line = pythonProcessReader.readLine(); + if (line.startsWith("Better speed can be achieved with apex installed")) + { + i--; + continue; + } + if (line.isEmpty()) + { + line = "0"; + } + String[] output = line.split(RegExp.anyWhiteSpaceGrouped.pattern()); + translations.add(processTranslationOutput(output)); + } + } + return translations; } private int[] parsePredictOutput(String line) @@ -226,8 +397,67 @@ private int[] parsePredictOutput(String line) } return output; } - + + private int[][] parsePredictOutputExtraLemma(String line) + { + String[] lineSplit = line.split(RegExp.anyWhiteSpaceGrouped.pattern()); + int[][] output = new int[lineSplit.length][]; + for (int i = 0 ; i < lineSplit.length ; i++) + { + String[] wordSplit = lineSplit[i].split("/"); + 
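The two prediction-line formats parsed in this region: the standard line carries one integer (a sense index) per word, whitespace-separated, while the extra-lemma mode (`--output_all_features`) emits one integer per output feature for each word, joined by `/`. A standalone sketch of the extra-lemma parsing on a made-up line (editor's illustration, not part of the patch):

```java
import java.util.Arrays;

public class PredictLineDemo
{
    public static void main(String[] args)
    {
        // hypothetical extra-lemma prediction line: 3 words, 2 output features each
        String line = "5/2 0/0 17/9";
        String[] words = line.split("\\s+");
        int[][] output = new int[words.length][];
        for (int i = 0; i < words.length; i++)
        {
            output[i] = Arrays.stream(words[i].split("/"))
                              .mapToInt(Integer::parseInt).toArray();
        }
        System.out.println(Arrays.deepToString(output)); // [[5, 2], [0, 0], [17, 9]]
    }
}
```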
output[i] = new int[wordSplit.length]; + for (int j = 0 ; j < wordSplit.length ; j++) + { + output[i][j] = Integer.valueOf(wordSplit[j]); + } + } + return output; + } + private void propagatePredictOutput(List words, int[] output, String senseTag, String confidenceTag) + { + for (int i = 0 ; i < output.length ; i++) + { + Word word = words.get(i); + if (word.hasAnnotation(senseTag)) continue; + if (filterLemma) + { + if (!word.hasAnnotation("lemma")) continue; + if (!word.hasAnnotation("pos")) continue; + int wordOutput = output[i]; + String pos = POSConverter.toWNPOS(word.getAnnotationValue("pos")); + List lemmas = word.getAnnotationValues("lemma", ";"); + for (String lemma : lemmas) + { + String wordKey = lemma + "%" + pos; + if (!wn.isWordKeyExists(wordKey)) continue; + List lemmaSenseKeys = wn.getSenseKeyListFromWordKey(wordKey); + for (String possibleSenseKey : lemmaSenseKeys) + { + String possibleSynsetKey = wn.getSynsetKeyFromSenseKey(possibleSenseKey); + if (reducedOutputVocabulary != null) + { + possibleSynsetKey = reducedOutputVocabulary.getOrDefault(possibleSynsetKey, possibleSynsetKey); + } + if (reversedOutputVocabulary.get(senseFeatureIndex).get(wordOutput).equals(possibleSynsetKey)) + { + word.setAnnotation(senseTag, possibleSenseKey); + //word.setAnnotation(confidenceTag, confidenceInfo); + } + } + } + } + else + { + int wordOutput = output[i]; + String possibleSynsetKey = reversedOutputVocabulary.get(senseFeatureIndex).get(wordOutput); + String possibleSenseKey = wn.getSenseKeyListFromSynsetKey(possibleSynsetKey).get(0); + word.setAnnotation(senseTag, possibleSenseKey); + } + } + } + + private void propagatePredictOutputExtraLemma(List words, int[][] output, String senseTag, String confidenceTag) { for (int i = 0 ; i < output.length ; i++) { @@ -235,43 +465,95 @@ private void propagatePredictOutput(List words, int[] output, String sense if (word.hasAnnotation(senseTag)) continue; if (!word.hasAnnotation("lemma")) continue; if (!word.hasAnnotation("pos")) continue; - int wordOutput = output[i]; + int[] wordOutput = output[i]; String pos = POSConverter.toWNPOS(word.getAnnotationValue("pos")); List lemmas = word.getAnnotationValues("lemma", ";"); - for (String lemma : lemmas) + // TODO: multiple lemmas + String lemma = lemmas.get(0); + String wordKey = lemma + "%" + pos; + if (!wn.isWordKeyExists(wordKey)) continue; + if (!reversedOutputAnnotationNames.containsKey(wordKey)) continue; + int extraLemmaFeatureIndex = reversedOutputAnnotationNames.get(wordKey); + List lemmaSenseKeys = wn.getSenseKeyListFromWordKey(wordKey); + for (String possibleSenseKey : lemmaSenseKeys) { - String wordKey = lemma + "%" + pos; - if (!wn.isWordKeyExists(wordKey)) continue; - List lemmaSenseKeys = wn.getSenseKeyListFromWordKey(wordKey); - for (String possibleSenseKey : lemmaSenseKeys) + String possibleSynsetKey = wn.getSynsetKeyFromSenseKey(possibleSenseKey); + if (reducedOutputVocabulary != null) { - String possibleSynsetKey = wn.getSynsetKeyFromSenseKey(possibleSenseKey); - if (reducedOutputVocabulary != null) - { - possibleSynsetKey = reducedOutputVocabulary.getOrDefault(possibleSynsetKey, possibleSynsetKey); - } - if (reversedOutputVocabulary.get(0).get(wordOutput).equals(possibleSynsetKey)) - { - word.setAnnotation(senseTag, possibleSenseKey); - //word.setAnnotation(confidenceTag, confidenceInfo); - } + possibleSynsetKey = reducedOutputVocabulary.getOrDefault(possibleSynsetKey, possibleSynsetKey); + } + if 
(reversedOutputVocabulary.get(extraLemmaFeatureIndex).get(wordOutput[extraLemmaFeatureIndex]).equals(possibleSynsetKey)) + { + word.setAnnotation(senseTag, possibleSenseKey); + //word.setAnnotation(confidenceTag, confidenceInfo); } } } } - private void disambiguateNoCatch(List words, String senseTag, String confidenceTag) throws Exception + private Sentence processTranslationOutput(String[] output) + { + Sentence translation = new Sentence(); + for (String wordValue : output) + { + new Word(wordValue, translation); + } + return translation; + } + + private void disambiguateNoCatch(List sentences, String senseTag, String confidenceTag) throws Exception { - writePredictInput(words); - readPredictOutput(words, senseTag, confidenceTag); + writePredictInput(sentences); + readPredictOutput(sentences, senseTag, confidenceTag); } - + @Override - public void disambiguate(List words, String senseTag, String confidenceTag) + protected void disambiguateFixedSentenceBatch(List sentences, String senseTag, String confidenceTag) + { + try + { + disambiguateNoCatch(sentences, senseTag, confidenceTag); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + + public List disambiguateAndTranslateDynamicSentenceBatch(List originalSentences) + { + return disambiguateAndTranslateDynamicSentenceBatch(originalSentences, "", ""); + } + + public List disambiguateAndTranslateDynamicSentenceBatch(List originalSentences, String newSenseTags, String confidenceTag) + { + List ret = new ArrayList<>(); + List sentences = new ArrayList<>(originalSentences); + while (sentences.size() > batchSize) + { + List subSentences = sentences.subList(0, batchSize); + ret.addAll(disambiguateAndTranslateFixedSentenceBatch(subSentences, newSenseTags, confidenceTag)); + subSentences.clear(); + } + if (!sentences.isEmpty()) + { + int paddingSize = batchSize - sentences.size(); + for (int i = 0 ; i < paddingSize ; i++) + { + sentences.add(new Sentence("")); + } + List translatedSentences = disambiguateAndTranslateFixedSentenceBatch(sentences, newSenseTags, confidenceTag); + translatedSentences = translatedSentences.subList(0, originalSentences.size()); + ret.addAll(translatedSentences); + } + return ret; + } + + private List disambiguateAndTranslateFixedSentenceBatch(List sentences, String senseTag, String confidenceTag) { try { - disambiguateNoCatch(words, senseTag, confidenceTag); + return disambiguateAndTranslateNoCatch(sentences, senseTag, confidenceTag); } catch (Exception e) { @@ -279,6 +561,12 @@ public void disambiguate(List words, String senseTag, String confidenceTag } } + private List disambiguateAndTranslateNoCatch(List sentences, String senseTag, String confidenceTag) throws Exception + { + writePredictInput(sentences); + return readPredictOutput(sentences, senseTag, confidenceTag); + } + @Override public void close() throws Exception { diff --git a/java/src/main/java/getalp/wsd/method/result/DisambiguationResult.java b/java/src/main/java/getalp/wsd/method/result/DisambiguationResult.java index 4e190c1..e93e0aa 100644 --- a/java/src/main/java/getalp/wsd/method/result/DisambiguationResult.java +++ b/java/src/main/java/getalp/wsd/method/result/DisambiguationResult.java @@ -1,5 +1,9 @@ package getalp.wsd.method.result; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + public class DisambiguationResult { public int total; @@ -8,24 +12,32 @@ public class DisambiguationResult public int bad; - public int attempted; - - public int missed; - public double time; - + + public Map totalPerPOS = 
initMapPerPos(); + + public Map goodPerPOS = initMapPerPos(); + + public Map badPerPOS = initMapPerPos(); + public DisambiguationResult() { this(0, 0, 0); } public DisambiguationResult(int total, int good, int bad) + { + this(total, good, bad, initMapPerPos(), initMapPerPos(), initMapPerPos()); + } + + public DisambiguationResult(int total, int good, int bad, Map totalPerPOS, Map goodPerPOS, Map badPerPOS) { this.total = total; this.good = good; this.bad = bad; - this.attempted = good + bad; - this.missed = total - attempted; + this.totalPerPOS = totalPerPOS; + this.goodPerPOS = goodPerPOS; + this.badPerPOS = badPerPOS; } public void concatenateResult(DisambiguationResult other) @@ -33,13 +45,27 @@ public void concatenateResult(DisambiguationResult other) total += other.total; good += other.good; bad += other.bad; - attempted += other.attempted; - missed += other.missed; + for (String pos : Arrays.asList("n", "v", "a", "r", "x")) + { + totalPerPOS.put(pos, totalPerPOS.get(pos) + other.totalPerPOS.get(pos)); + goodPerPOS.put(pos, goodPerPOS.get(pos) + other.goodPerPOS.get(pos)); + badPerPOS.put(pos, badPerPOS.get(pos) + other.badPerPOS.get(pos)); + } } - + + public int attempted() + { + return good + bad; + } + + public int missed() + { + return total - attempted(); + } + public double coverage() { - return ratioPercent(total - missed, total); + return ratioPercent(attempted(), total); } public double scoreRecall() @@ -49,7 +75,7 @@ public double scoreRecall() public double scorePrecision() { - return ratioPercent(good, total - missed); + return ratioPercent(good, attempted()); } public double scoreF1() @@ -59,8 +85,50 @@ public double scoreF1() return 2.0 * ((p * r) / (p + r)); } + public int attemptedPerPOS(String pos) + { + return goodPerPOS.get(pos) + badPerPOS.get(pos); + } + + public int missedPerPOS(String pos) + { + return totalPerPOS.get(pos) - attemptedPerPOS(pos); + } + + public double coveragePerPOS(String pos) + { + return ratioPercent(attemptedPerPOS(pos), totalPerPOS.get(pos)); + } + + public double scoreRecallPerPOS(String pos) + { + return ratioPercent(goodPerPOS.get(pos), totalPerPOS.get(pos)); + } + + public double scorePrecisionPerPOS(String pos) + { + return ratioPercent(goodPerPOS.get(pos), attemptedPerPOS(pos)); + } + + public double scoreF1PerPOS(String pos) + { + double r = scoreRecallPerPOS(pos); + double p = scorePrecisionPerPOS(pos); + return 2.0 * ((p * r) / (p + r)); + } + private double ratioPercent(double num, double den) { return (num / den) * 100; } + + private static Map initMapPerPos() + { + Map map = new HashMap<>(); + for (String pos : Arrays.asList("n", "v", "a", "r", "x")) + { + map.put(pos, 0); + } + return map; + } } \ No newline at end of file diff --git a/java/src/main/java/getalp/wsd/method/result/MultipleDisambiguationResult.java b/java/src/main/java/getalp/wsd/method/result/MultipleDisambiguationResult.java index 3f6e420..49df902 100644 --- a/java/src/main/java/getalp/wsd/method/result/MultipleDisambiguationResult.java +++ b/java/src/main/java/getalp/wsd/method/result/MultipleDisambiguationResult.java @@ -36,7 +36,7 @@ public double timeMean() public double[] allScores() { - return results.stream().mapToDouble((DisambiguationResult r) -> {return r.scoreF1();}).toArray(); + return results.stream().mapToDouble(DisambiguationResult::scoreF1).toArray(); } public double[] allTimes() diff --git a/java/src/main/java/getalp/wsd/ufsac/simple/core/Annotation.java b/java/src/main/java/getalp/wsd/ufsac/simple/core/Annotation.java new file mode 100644 index 
0000000..4ad7f93 --- /dev/null +++ b/java/src/main/java/getalp/wsd/ufsac/simple/core/Annotation.java @@ -0,0 +1,61 @@ +package getalp.wsd.ufsac.simple.core; + +import getalp.wsd.common.utils.StringUtils; + +import java.util.Arrays; +import java.util.List; + +public class Annotation +{ + private String annotationName; + + private String annotationValue; + + public Annotation(String name, String value) + { + if (name == null) this.annotationName = ""; + else this.annotationName = name; + if (value == null) this.annotationValue = ""; + else this.annotationValue = value; + } + + public Annotation(String name, List values, String delimiter) + { + if (name == null) this.annotationName = ""; + else this.annotationName = name; + if (values == null) this.annotationValue = ""; + else this.annotationValue = StringUtils.join(values, delimiter); + } + + public String getAnnotationName() + { + return annotationName; + } + + public String getAnnotationValue() + { + return annotationValue; + } + + public void setAnnotationValue(String value) + { + if (value == null) this.annotationValue = ""; + else this.annotationValue = value; + } + + public List getAnnotationValues(String delimiter) + { + return Arrays.asList(annotationValue.split(delimiter)); + } + + public void setAnnotationValues(List values, String delimiter) + { + if (values == null) this.annotationValue = ""; + else this.annotationValue = StringUtils.join(values, delimiter); + } + + public String toString() + { + return annotationName + "=" + annotationValue; + } +} diff --git a/java/src/main/java/getalp/wsd/ufsac/simple/core/LexicalEntity.java b/java/src/main/java/getalp/wsd/ufsac/simple/core/LexicalEntity.java new file mode 100644 index 0000000..fd9238e --- /dev/null +++ b/java/src/main/java/getalp/wsd/ufsac/simple/core/LexicalEntity.java @@ -0,0 +1,99 @@ +package getalp.wsd.ufsac.simple.core; + +import java.util.*; + +public class LexicalEntity +{ + private Map annotationsAsMap; + + private List annotationsAsList; + + public LexicalEntity() + { + annotationsAsList = new ArrayList<>(); + annotationsAsMap = new HashMap<>(); + } + + public LexicalEntity(getalp.wsd.ufsac.core.LexicalEntity lexicalEntityToCopy) + { + this(); + for (getalp.wsd.ufsac.core.Annotation annotationToCopy : lexicalEntityToCopy.getAnnotations()) + { + setAnnotation(annotationToCopy.getAnnotationName(), annotationToCopy.getAnnotationValue()); + } + } + + public List getAnnotations() + { + return Collections.unmodifiableList(annotationsAsList); + } + + public String getAnnotationValue(String annotationName) + { + if (!annotationsAsMap.containsKey(annotationName)) return ""; + return annotationsAsMap.get(annotationName).getAnnotationValue(); + } + + public List getAnnotationValues(String annotationName, String delimiter) + { + if (!annotationsAsMap.containsKey(annotationName)) return Collections.emptyList(); + return annotationsAsMap.get(annotationName).getAnnotationValues(delimiter); + } + + public void setAnnotation(String annotationName, String annotationValue) + { + if (annotationName == null || annotationName.equals("")) return; + if (annotationValue == null) annotationValue = ""; + if (hasAnnotation(annotationName)) + { + annotationsAsMap.get(annotationName).setAnnotationValue(annotationValue); + } + else + { + Annotation a = new Annotation(annotationName, annotationValue); + annotationsAsList.add(a); + annotationsAsMap.put(annotationName, a); + } + } + + public void setAnnotation(String annotationName, List annotationValues, String delimiter) + { + if (annotationName == 
null || annotationName.equals("")) return; + if (annotationValues == null) annotationValues = Collections.emptyList(); + if (hasAnnotation(annotationName)) + { + annotationsAsMap.get(annotationName).setAnnotationValues(annotationValues, delimiter); + } + else + { + Annotation a = new Annotation(annotationName, annotationValues, delimiter); + annotationsAsList.add(a); + annotationsAsMap.put(annotationName, a); + } + } + + public void removeAnnotation(String annotationName) + { + annotationsAsList.removeIf(a -> a.getAnnotationName().equals(annotationName)); + annotationsAsMap.remove(annotationName); + } + + public void removeAllAnnotations() + { + annotationsAsList.clear(); + annotationsAsMap.clear(); + } + + public boolean hasAnnotation(String annotationName) + { + return !getAnnotationValue(annotationName).isEmpty(); + } + + public void transfertAnnotationsToCopy(LexicalEntity copy) + { + for (Annotation a : this.annotationsAsList) + { + copy.setAnnotation(a.getAnnotationName(), a.getAnnotationValue()); + } + } +} diff --git a/java/src/main/java/getalp/wsd/ufsac/simple/core/Sentence.java b/java/src/main/java/getalp/wsd/ufsac/simple/core/Sentence.java new file mode 100644 index 0000000..112d731 --- /dev/null +++ b/java/src/main/java/getalp/wsd/ufsac/simple/core/Sentence.java @@ -0,0 +1,101 @@ +package getalp.wsd.ufsac.simple.core; + +import getalp.wsd.common.utils.RegExp; + +import java.util.ArrayList; +import java.util.List; + +public class Sentence extends LexicalEntity +{ + List words = new ArrayList<>(); + + public Sentence() + { + super(); + } + + public Sentence(getalp.wsd.ufsac.core.Sentence sentenceToCopy) + { + super(sentenceToCopy); + for (getalp.wsd.ufsac.core.Word word : sentenceToCopy.getWords()) + { + this.addWord(new Word(word)); + } + } + + public Sentence(String value) + { + addWordsFromString(value); + } + + public Sentence(List words) + { + for (Word word : new ArrayList<>(words)) + { + addWord(word); + } + } + + public void addWord(Word word) + { + words.add(word); + } + + public void removeWord(Word word) + { + words.remove(word); + } + + public void removeAllWords() + { + words.clear(); + } + + public List getWords() + { + return words; + } + + public void limitSentenceLength(int maxLength) + { + if (words.size() > maxLength) + { + words = words.subList(0, maxLength); + } + } + + public Sentence clone() + { + Sentence newSentence = new Sentence(); + transfertWordsToCopy(newSentence); + transfertAnnotationsToCopy(newSentence); + return newSentence; + } + + public void transfertWordsToCopy(Sentence other) + { + for (Word word : getWords()) + { + other.addWord(word.clone()); + } + } + + public void addWordsFromString(String value) + { + String[] wordsArray = value.split(RegExp.anyWhiteSpaceGrouped.toString()); + for (String wordInArray : wordsArray) + { + addWord(new Word(wordInArray)); + } + } + + public String toString() + { + String ret = ""; + for (Word word : getWords()) + { + ret += word.toString() + " "; + } + return ret.trim(); + } +} diff --git a/java/src/main/java/getalp/wsd/ufsac/simple/core/Word.java b/java/src/main/java/getalp/wsd/ufsac/simple/core/Word.java new file mode 100644 index 0000000..299989c --- /dev/null +++ b/java/src/main/java/getalp/wsd/ufsac/simple/core/Word.java @@ -0,0 +1,41 @@ +package getalp.wsd.ufsac.simple.core; + +public class Word extends LexicalEntity +{ + public Word() + { + super(); + } + + public Word(String value) + { + setAnnotation("surface_form", value); + } + + public Word(getalp.wsd.ufsac.core.Word wordToCopy) + { + 
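The simple-core classes added by this patch (`Annotation`, `LexicalEntity`, `Sentence`, `Word`) form a lightweight copy of the UFSAC data model: a sentence is a list of words, and every word carries name/value annotations such as `lemma`, `pos`, or a sense tag. A short usage sketch (editor's illustration, not part of the patch; assumes the project classes are on the classpath):

```java
import getalp.wsd.ufsac.simple.core.Sentence;
import getalp.wsd.ufsac.simple.core.Word;

public class SimpleCoreDemo
{
    public static void main(String[] args)
    {
        Sentence sentence = new Sentence("The mouse eats"); // tokenized on whitespace
        Word mouse = sentence.getWords().get(1);
        mouse.setAnnotation("lemma", "mouse");              // annotations are name/value pairs
        mouse.setAnnotation("pos", "NOUN");
        System.out.println(mouse.getAnnotationValue("lemma")); // prints: mouse
        System.out.println(sentence);                          // prints: The mouse eats
    }
}
```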
super(wordToCopy); + } + + public void setValue(String value) + { + setAnnotation("surface_form", value); + } + + public String getValue() + { + return getAnnotationValue("surface_form"); + } + + public Word clone() + { + Word copy = new Word(); + transfertAnnotationsToCopy(copy); + return copy; + } + + public String toString() + { + return getAnnotationValue("surface_form"); + } +} diff --git a/java/src/main/java/getalp/wsd/utils/WordnetUtils.java b/java/src/main/java/getalp/wsd/utils/WordnetUtils.java index c23148e..fce4123 100644 --- a/java/src/main/java/getalp/wsd/utils/WordnetUtils.java +++ b/java/src/main/java/getalp/wsd/utils/WordnetUtils.java @@ -1,5 +1,8 @@ package getalp.wsd.utils; +import java.io.BufferedReader; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -14,6 +17,16 @@ public class WordnetUtils { + public static String extractLemmaFromSenseKey(String senseKey) + { + return senseKey.substring(0, senseKey.indexOf("%")); + } + + public static String extractPOSFromSenseKey(String senseKey) + { + return POSConverter.toWNPOS(Integer.valueOf(senseKey.substring(senseKey.indexOf("%") + 1, senseKey.indexOf("%") + 2))); + } + public static Set getUniqueSynsetKeysFromSenseKeys(WordnetHelper wn, List senseKeys) { Set synsetKeys = new HashSet<>(); @@ -23,7 +36,7 @@ public static Set getUniqueSynsetKeysFromSenseKeys(WordnetHelper wn, Lis } return synsetKeys; } - + public static void getHypernymHierarchy(WordnetHelper wn, String synsetKey, List hypernymyHierarchy) { if (hypernymyHierarchy.contains(synsetKey)) return; @@ -34,14 +47,33 @@ public static void getHypernymHierarchy(WordnetHelper wn, String synsetKey, List getHypernymHierarchy(wn, hypernymSynsetKeys.get(0), hypernymyHierarchy); } } - + public static List getHypernymHierarchy(WordnetHelper wn, String synsetKey) { List hypernymyHierarchy = new ArrayList<>(); getHypernymHierarchy(wn, synsetKey, hypernymyHierarchy); return hypernymyHierarchy; } - + + public static void getHypernymHierarchyIncludeInstanceHypernyms(WordnetHelper wn, String synsetKey, List hypernymyHierarchy) + { + if (hypernymyHierarchy.contains(synsetKey)) return; + hypernymyHierarchy.add(synsetKey); + List hypernymSynsetKeys = wn.getHypernymSynsetKeysFromSynsetKey(synsetKey); + hypernymSynsetKeys.addAll(wn.getInstanceHypernymSynsetKeysFromSynsetKey(synsetKey)); + if (!hypernymSynsetKeys.isEmpty()) + { + getHypernymHierarchy(wn, hypernymSynsetKeys.get(0), hypernymyHierarchy); + } + } + + public static List getHypernymHierarchyIncludeInstanceHypernyms(WordnetHelper wn, String synsetKey) + { + List hypernymyHierarchy = new ArrayList<>(); + getHypernymHierarchyIncludeInstanceHypernyms(wn, synsetKey, hypernymyHierarchy); + return hypernymyHierarchy; + } + public static Map getReducedSynsetKeysWithHypernyms1(WordnetHelper wn, String[] corpora, boolean removeMonosemics, boolean removeCoarseGrained) { String senseTag = "wn" + wn.getVersion() + "_key"; @@ -354,4 +386,151 @@ public static Map getReducedSynsetKeysWithHypernyms4(WordnetHelp return synsetKeysToSimpleSynsetKey; } + + public static Map getSenseCompressionThroughHypernymsAndInstanceHypernymsClusters(WordnetHelper wn, Map currentClusters) + { + Map> allVocabulary = new HashMap<>(); + Map> allHypernymHierarchy = new HashMap<>(); + + for (String wordKey : wn.getVocabulary()) + { + allVocabulary.putIfAbsent(wordKey, new HashSet<>()); + for (String senseKey : wn.getSenseKeyListFromWordKey(wordKey)) + { + String synsetKey = 
wn.getSynsetKeyFromSenseKey(senseKey); + allVocabulary.get(wordKey).add(synsetKey); + allHypernymHierarchy.putIfAbsent(synsetKey, getHypernymHierarchyIncludeInstanceHypernyms(wn, synsetKey)); + } + } + + Set necessarySynsetKeys = new HashSet<>(); + + for (String wordKey : allVocabulary.keySet()) + { + for (String synsetKey : allVocabulary.get(wordKey)) + { + List hypernymHierarchy = allHypernymHierarchy.get(synsetKey); + int whereToStop = hypernymHierarchy.size(); + boolean found = false; + for (int i = 0 ; i < hypernymHierarchy.size() ; i++) + { + if (found) break; + for (String synsetKey2 : allVocabulary.get(wordKey)) + { + if (synsetKey.equals(synsetKey2)) continue; + if (found) break; + List hypernymHierarchy2 = allHypernymHierarchy.get(synsetKey2); + for (int j = 0 ; j < hypernymHierarchy2.size() ; j++) + { + if (hypernymHierarchy.get(i).equals(hypernymHierarchy2.get(j))) + { + whereToStop = i; + found = true; + break; + } + } + } + } + if (whereToStop == 0) + { + necessarySynsetKeys.add(hypernymHierarchy.get(0)); + } + else // > 0 + { + necessarySynsetKeys.add(hypernymHierarchy.get(whereToStop - 1)); + } + } + } + + Map synsetKeysToSimpleSynsetKey = new HashMap<>(); + + for (String synsetKey : allHypernymHierarchy.keySet()) + { + List hypernymHierarchy = allHypernymHierarchy.get(synsetKey); + for (int i = 0 ; i < hypernymHierarchy.size() ; i++) + { + if (necessarySynsetKeys.contains(hypernymHierarchy.get(i))) + { + synsetKeysToSimpleSynsetKey.put(synsetKey, hypernymHierarchy.get(i)); + break; + } + } + } + + return synsetKeysToSimpleSynsetKey; + } + + public static Map getSenseCompressionThroughHypernymsClusters(WordnetHelper wn, Map currentClusters) + { + return getReducedSynsetKeysWithHypernyms3(wn); + } + + public static Map getSenseCompressionThroughAntonymsClusters(WordnetHelper wn, Map currentClusters) + { + Map antonymClusters = new HashMap<>(); + for (String synsetKey : currentClusters.values()) + { + if (antonymClusters.containsKey(synsetKey)) continue; + List antonymSynsetKeys = wn.getAntonymSynsetKeysFromSynsetKey(synsetKey); + for (String antonymSynsetKey : antonymSynsetKeys) + { + antonymClusters.put(antonymSynsetKey, synsetKey); + } + antonymClusters.put(synsetKey, synsetKey); + } + Map newClusters = new HashMap<>(); + for (String synsetKey : currentClusters.keySet()) + { + newClusters.put(synsetKey, antonymClusters.get(currentClusters.get(synsetKey))); + } + return newClusters; + } + + public static Map getSenseCompressionClusters(WordnetHelper wn, boolean hypernyms, boolean instanceHypernyms, boolean antonyms) + { + Map clusters = new HashMap<>(); + for (String wordKey : wn.getVocabulary()) + { + for (String senseKey : wn.getSenseKeyListFromWordKey(wordKey)) + { + String synsetKey = wn.getSynsetKeyFromSenseKey(senseKey); + clusters.putIfAbsent(synsetKey, synsetKey); + } + } + + if (hypernyms) + { + if (instanceHypernyms) + { + clusters = getSenseCompressionThroughHypernymsAndInstanceHypernymsClusters(wn, clusters); + } + else + { + clusters = getSenseCompressionThroughHypernymsClusters(wn, clusters); + } + } + + if (antonyms) + { + clusters = getSenseCompressionThroughAntonymsClusters(wn, clusters); + } + + return clusters; + } + + public static Map getSenseCompressionClustersFromFile(String filePath) + { + try + { + Map mapping = new HashMap<>(); + BufferedReader reader = Files.newBufferedReader(Paths.get(filePath)); + reader.lines().map(line -> line.split(" ")).forEach(line -> mapping.put(line[0], line[1])); + reader.close(); + return mapping; + } + catch (Exception 
diff --git a/python/getalp/common/common.py b/python/getalp/common/common.py index 81d9034..2daf7fd 100644 --- a/python/getalp/common/common.py +++ b/python/getalp/common/common.py @@ -1,6 +1,79 @@ import pathlib +import sys +import os def create_directory_if_not_exists(directory_path): pathlib.Path(directory_path).mkdir(parents=True, exist_ok=True) + +def get_abs_path(path): + if path is None: + return None + elif os.path.isabs(path): + return path + else: + return os.path.abspath(path) + + +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) + + +class Struct: + def __init__(self, **entries): + self.__dict__.update(entries) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise RuntimeError('Boolean value expected.') + + +def set_if_not_none(value, value_if_none): + return value_if_none if value is None else value + + +def get_value_as_str_list(value): + if value is None: + return [] + elif isinstance(value, str): + return [value] + else: + return value + + +def get_value_as_int_list(value): + if value is None: + return [] + elif isinstance(value, int): + return [value] + else: + return value + + +def get_value_as_bool_list(value): + if value is None: + return [] + elif isinstance(value, bool): + return [value] + else: + return value + + +def pad_list(list_to_pad, pad_length, pad_value): + for i in range(len(list_to_pad), pad_length): + list_to_pad.append(pad_value) + + +def count_lines(file_path): + file = open(file_path) + line_count = 0 + for _ in file: + line_count += 1 + file.close() + return line_count diff --git a/python/getalp/wsd/common.py b/python/getalp/wsd/common.py index d95c325..e6816a0 100644 --- a/python/getalp/wsd/common.py +++ b/python/getalp/wsd/common.py @@ -1,16 +1,30 @@ -import random -import numpy as np import json from typing import List +from getalp.common.common import eprint, count_lines +from getalp.wsd.torch_fix import * +from getalp.wsd.torch_utils import cpu_device +from torch.nn.utils.rnn import pad_sequence + +pad_token_index = 0 +unk_token_index = 1 +bos_token_index = 2 +eos_token_index = 3 +reserved_token_count = 4 +pad_token = "<pad>" +unk_token = "<unk>" +bos_token = "<bos>" +eos_token = "<eos>" + + +def get_vocabulary(vocabulary_file_path): + vocabulary_file = open(vocabulary_file_path) + vocabulary = [line.rstrip().split()[-1] for line in vocabulary_file] + vocabulary_file.close() + return vocabulary def get_vocabulary_size(vocabulary_file_path): - vocabulary_size = 0 - vocabulary_file = open(vocabulary_file_path) - for _ in vocabulary_file: - vocabulary_size += 1 - vocabulary_file.close() - return vocabulary_size + return count_lines(vocabulary_file_path) def get_embeddings_size(embeddings_file_path): @@ -21,171 +35,219 @@ def get_embeddings_size(embeddings_file_path): return embeddings_size -def load_vocabulary(vocabulary_file_path): - vocabulary_file = open(vocabulary_file_path) - vocabulary = [] - for line in vocabulary_file: - vocabulary.append(line.split()[1]) - vocabulary_file.close() - return vocabulary - - def get_pretrained_embeddings(pretrained_model_path): - embeddings_count = 2 + get_vocabulary_size(pretrained_model_path) + embeddings_count = reserved_token_count + get_vocabulary_size(pretrained_model_path) embeddings_size = get_embeddings_size(pretrained_model_path) - embeddings = np.empty(shape=(embeddings_count, embeddings_size), dtype=np.float32) - embeddings[0] = 
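Note the convention the constants above establish: indices 0-3 are reserved, so row `i + reserved_token_count` of a pretrained embedding matrix corresponds to line `i` of the embeddings file. A small illustration of that offset (the `token_to_index` helper below is hypothetical, not part of the repository):

```python
# Hypothetical helper illustrating the reserved-index convention above.
RESERVED = 4  # <pad>=0, <unk>=1, <bos>=2, <eos>=3

def token_to_index(token, vocabulary):
    # Embedding row for vocabulary[i] is i + RESERVED.
    return vocabulary.index(token) + RESERVED if token in vocabulary else 1

vocab = ["the", "cat", "sat"]
assert token_to_index("cat", vocab) == 5   # second line of the file
assert token_to_index("dog", vocab) == 1   # falls back to <unk>
```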
np.zeros(embeddings_size) # = 0 - embeddings[1] = np.zeros(embeddings_size) # = 1 - i = 2 + embeddings = torch_empty((embeddings_count, embeddings_size), dtype=torch_float32, device=cpu_device) + embeddings[pad_token_index] = torch_zeros(embeddings_size) # = 0 + embeddings[unk_token_index] = torch_zeros(embeddings_size) # = 1 + embeddings[bos_token_index] = torch_zeros(embeddings_size) # = 2 + embeddings[eos_token_index] = torch_zeros(embeddings_size) # = 3 + i = reserved_token_count f = open(pretrained_model_path) for line in f: vector = line.split()[1:] + if len(vector) != embeddings_size: + eprint("Warning: cannot load pretrained embedding at index " + str(i-reserved_token_count)) + continue vector = [float(i) for i in vector] - embeddings[i] = np.array(vector, dtype=np.float32) + embeddings[i] = torch_tensor(vector, dtype=torch_float32, device=cpu_device) i += 1 f.close() return embeddings -def read_sample_x_or_y_from_string(string): - sample_x: List = None +def read_sample_x_from_string(string: str, feature_count: int, clear_text: List[bool]): + sample_x: List = [[] for _ in range(feature_count)] for word in string.split(): word_features = word.split('/') - if sample_x is None: - sample_x = [[] for _ in range(len(word_features))] - for i in range(len(word_features)): - sample_x[i].append(int(word_features[i])) - for i in range(len(sample_x)): - sample_x[i] = np.array(sample_x[i], dtype=np.int64) + for i in range(feature_count): + if clear_text[i]: + sample_x[i].append(word_features[i].replace("<slash/>", "/")) + else: + sample_x[i].append(int(word_features[i])) + for i in range(feature_count): + if not clear_text[i]: + sample_x[i] = torch_tensor(sample_x[i], dtype=torch_long, device=cpu_device) return sample_x -def read_sample_z_from_string(string): - sample_y = None +def read_sample_y_from_string(string: str, feature_count: int): + sample_y: List = [[] for _ in range(feature_count)] for word in string.split(): word_features = word.split('/') - if sample_y is None: - sample_y = [[] for _ in range(len(word_features))] - for i in range(len(word_features)): - sample_y[i].append([int(j) for j in word_features[i].split(";")]) + for i in range(feature_count): + sample_y[i].append(int(word_features[i])) + for i in range(feature_count): + sample_y[i] = torch_tensor(sample_y[i], dtype=torch_long, device=cpu_device) return sample_y -def read_all_samples_from_file(file_path): +def read_sample_z_from_string(string: str, feature_count: int): + sample_z = [[] for _ in range(feature_count)] + for word in string.split(): + word_features = word.split('/') + for i in range(feature_count): + sample_z[i].append([int(j) for j in word_features[i].split(";")]) + return sample_z + + +def read_sample_t_from_string(string: str, feature_count: int): + sample_t: List = [[] for _ in range(feature_count)] + for word in string.split(): + word_features = word.split('/') + for i in range(feature_count): + sample_t[i].append(int(word_features[i])) + for i in range(feature_count): + sample_t[i].append(eos_token_index) + sample_t[i] = torch_tensor(sample_t[i], dtype=torch_long, device=cpu_device) + return sample_t + + +def read_samples_from_file(file_path: str, input_clear_text: List[bool], output_features: int, output_translations: int, output_translation_features: int, output_translation_clear_text: bool, limit: int = -1): file = open(file_path, "r") samples = [] sample_triplet = [] + sample_tt = [] i = 0 for line in file: if i == 0: - sample_x = read_sample_x_or_y_from_string(line) + if limit > 0 and len(samples) >= limit > 
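The four readers above share one line format: one sentence per line, one `/`-separated feature group per word. A toy illustration (made-up indices, here with two input features and one output feature; `z` lines list every acceptable sense joined by `;`):

```python
# Toy x/y/z lines for a 3-word sentence (made-up indices; real files
# contain vocabulary indices, or raw tokens for clear-text features).
x_line = "12/305 7/204 903/88"   # two input features per word
y_line = "0 0 42"                # one gold sense index per word
z_line = "0 0 42;57"             # all acceptable senses per word

# Mirrors read_sample_z_from_string for a single output feature:
sample_z = [[int(j) for j in word.split(";")] for word in z_line.split()]
assert sample_z == [[0], [0], [42, 57]]
```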
0: + break + sample_x = read_sample_x_from_string(line, feature_count=len(input_clear_text), clear_text=input_clear_text) sample_triplet = [sample_x] - i = 1 + if output_features > 0: + i = 1 + else: + sample_triplet.append([]) + sample_triplet.append([]) + i = 3 elif i == 1: - sample_y = read_sample_x_or_y_from_string(line) + sample_y = read_sample_y_from_string(line, feature_count=output_features) sample_triplet.append(sample_y) i = 2 elif i == 2: - sample_z = read_sample_z_from_string(line) + sample_z = read_sample_z_from_string(line, feature_count=output_features) sample_triplet.append(sample_z) - samples.append(sample_triplet) - i = 0 + if output_translations > 0: + i = 3 + else: + sample_triplet.append([]) + samples.append(sample_triplet) + i = 0 + elif i == 3: + if output_translation_clear_text: + raise NotImplementedError + # TODO: + # sample_t = read_sample_x_from_string(line, feature_count=0, clear_text=[True], add_eos_token=False) + else: + sample_t = read_sample_t_from_string(line, feature_count=output_translation_features) + sample_tt.append(sample_t) + if len(sample_tt) >= output_translations: + sample_triplet.append(sample_tt) + sample_tt = [] + samples.append(sample_triplet) + i = 0 file.close() return samples -def create_fake_batch(batch_size, sample_size, input_features, input_vocabulary_sizes, output_features, output_vocabulary_sizes): - batch_x = [] - for i in range(input_features): - feature_batch_x = [] - for j in range(batch_size): - sample_x = [] - for k in range(sample_size): - sample_x.append(random.randrange(0, input_vocabulary_sizes[i])) - feature_batch_x.append(sample_x) - feature_batch_x = np.array(feature_batch_x, dtype=np.int64) - batch_x.append(feature_batch_x) - batch_y = [] - for i in range(output_features): - feature_batch_y = [] - for j in range(batch_size): - sample_y = [] - for k in range(sample_size): - sample_y.append(random.randrange(0, output_vocabulary_sizes[i])) - feature_batch_y.append(sample_y) - feature_batch_y = np.array(feature_batch_y, dtype=np.int64) - batch_y.append(feature_batch_y) - return batch_x, batch_y - - -def read_batch_from_samples(samples, batch_size, current_index): - batch_x = None - batch_y = None - batch_z = None +def pad_batch_x(batch_x, clear_text): + for i in range(len(batch_x)): + if not clear_text[i]: + batch_x[i] = pad_sequence(batch_x[i], batch_first=True) + + +def pad_batch_y(batch_y): + for i in range(len(batch_y)): + batch_y[i] = pad_sequence(batch_y[i], batch_first=True) + + +def pad_batch_tt(batch_tt): + for i in range(len(batch_tt)): + for j in range(len(batch_tt[i])): + batch_tt[i][j] = pad_sequence(batch_tt[i][j], batch_first=True) + + +def unpad_turn_to_text_and_remove_bpe_of_batch_t(batch_t, vocabulary: List[str]): + ret: List[str] = [] + for k in range(len(batch_t)): + str_as_list = [] + for l in range(len(batch_t[k])): + value = batch_t[k][l].item() + if value == eos_token_index or value == pad_token_index: + break + value = vocabulary[value] + str_as_list.append(value) + str_as_str = " ".join(str_as_list) + str_as_str = str_as_str.replace("@@ ", "") + str_as_str = str_as_str.replace(" ##", "") + ret.append(str_as_str) + return ret + + +def read_batch_from_samples(samples, batch_size: int, token_per_batch: int, current_index: int, input_features: int, output_features: int, output_translations: int, output_translation_features: int, input_clear_text: List[bool], output_translation_clear_text: bool): + batch_x = [[] for _ in range(input_features)] + batch_y = [[] for _ in range(output_features)] + batch_z = 
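The two string replacements in `unpad_turn_to_text_and_remove_bpe_of_batch_t` above undo the subword markers of the two tokenizer families handled here: `"@@ "` for subword-nmt-style BPE and `" ##"` for BERT WordPieces. For instance:

```python
# How the replacements above restore plain words from subword output.
assert "em@@ bed@@ dings work".replace("@@ ", "") == "embeddings work"
assert "em ##bed ##dings work".replace(" ##", "") == "embeddings work"
```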
[[] for _ in range(output_features)] + batch_tt = [[[] for __ in range(output_translation_features)] for _ in range(output_translations)] actual_batch_size = 0 reached_eof = False max_length = 0 + max_length_tt: List[int] = [0 for _ in range(output_translations)] - for i in range(current_index, current_index + batch_size): - if i >= len(samples): + while True: + if current_index >= len(samples): reached_eof = True break + if actual_batch_size >= batch_size > 0: + break - sample = samples[i] + sample = samples[current_index] + + max_length_if_accepted = max(max_length, len(sample[0][0])) + if (actual_batch_size + 1) * max_length_if_accepted > token_per_batch > 0: + break + max_length = max_length_if_accepted sample_x = sample[0] - if batch_x is None: - batch_x = [[] for _ in range(len(sample_x))] - for j in range(len(sample_x)): + for j in range(input_features): batch_x[j].append(sample_x[j]) - max_length = max(max_length, len(sample_x[0])) sample_y = sample[1] - if batch_y is None: - batch_y = [[] for _ in range(len(sample_y))] - for j in range(len(sample_y)): + for j in range(output_features): batch_y[j].append(sample_y[j]) sample_z = sample[2] - if batch_z is None: - batch_z = [[] for _ in range(len(sample_z))] - for j in range(len(sample_z)): + for j in range(output_features): batch_z[j].append(sample_z[j]) - actual_batch_size += 1 - - for j in range(0, actual_batch_size): - padding_needed = max_length - len(batch_x[0][j]) - for i in range(len(batch_x)): - batch_x[i][j] = np.pad(batch_x[i][j], (0, padding_needed), mode='constant', constant_values=0) - for i in range(len(batch_y)): - batch_y[i][j] = np.pad(batch_y[i][j], (0, padding_needed), mode='constant', constant_values=0) - - if batch_x is None: - batch_x = [] - - if batch_y is None: - batch_y = [] + sample_tt = sample[3] + for j in range(output_translations): + for k in range(output_translation_features): + batch_tt[j][k].append(sample_tt[j][k]) + max_length_tt[j] = max(max_length_tt[j], len(sample_tt[j][0])) - if batch_z is None: - batch_z = [] - - for i in range(len(batch_x)): - batch_x[i] = np.array(batch_x[i], dtype=np.int64) + actual_batch_size += 1 + current_index += 1 - for i in range(len(batch_y)): - batch_y[i] = np.array(batch_y[i], dtype=np.int64) + pad_batch_x(batch_x, input_clear_text) + pad_batch_y(batch_y) + pad_batch_tt(batch_tt) # TODO: output_translation_clear_text - return batch_x, batch_y, batch_z, actual_batch_size, reached_eof + return batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof -def save_training_info(file_path, current_ensemble, current_epoch, current_batch, train_line, current_best_wsd, current_best_loss): - info = {"current_ensemble":current_ensemble, - "current_epoch":current_epoch, - "current_batch":current_batch, - "train_line":train_line, - "current_best_wsd":current_best_wsd, - "current_best_loss":current_best_loss, +def save_training_info(file_path, current_ensemble, current_epoch, current_batch, current_batch_total, train_line, current_best_loss, current_best_wsd, current_best_bleu, random_seed): + info = {"current_ensemble": current_ensemble, + "current_epoch": current_epoch, + "current_batch": current_batch, + "current_batch_total": current_batch_total, + "train_line": train_line, + "current_best_loss": current_best_loss, + "current_best_wsd": current_best_wsd, + "current_best_bleu": current_best_bleu, + "random_seed": random_seed } file = open(file_path, "w") json.dump(info, file) @@ -196,24 +258,13 @@ def load_training_info(file_path): file = open(file_path, "r") info = 
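`read_batch_from_samples` above grows a batch until either `batch_size` sentences are taken or padding all of them to the longest sentence would exceed `token_per_batch`; a non-positive value disables the corresponding limit. A condensed, standalone sketch of that stopping rule over a list of sentence lengths:

```python
# Condensed sketch of the stopping rule in read_batch_from_samples:
# stop when (samples_taken + 1) * new_max_length would exceed the cap.
def take_batch(lengths, batch_size, token_per_batch, start=0):
    i, max_len = start, 0
    while i < len(lengths):
        if 0 < batch_size <= i - start:
            break
        new_max = max(max_len, lengths[i])
        if 0 < token_per_batch < (i - start + 1) * new_max:
            break
        max_len, i = new_max, i + 1
    return i  # index of the first sample left for the next batch

assert take_batch([3, 5, 4, 9], batch_size=0, token_per_batch=12) == 2
```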
json.load(file) file.close() - return info["current_ensemble"], info["current_epoch"], info["current_batch"], info["train_line"], info["current_best_wsd"], info["current_best_loss"] - - -def save_training_losses(file_path, train_loss, dev_loss, dev_wsd): - file = open(file_path, "a") - file.write(str(train_loss) + " " + str(dev_loss) + " " + str(dev_wsd) + "\n") - file.close() - - -def load_training_losses(file_path): - file = open(file_path, "r") - train_losses = [] - dev_losses = [] - dev_wsd = [] - for line in file: - linesplit = line.split() - train_losses.append(float(linesplit[0])) - dev_losses.append(float(linesplit[1])) - dev_wsd.append(float(linesplit[2])) - file.close() - return train_losses, dev_losses, dev_wsd + return (info["current_ensemble"], + info["current_epoch"], + info["current_batch"], + info["current_batch_total"], + info["train_line"], + info["current_best_loss"], + info["current_best_wsd"], + info["current_best_bleu"], + info["random_seed"] + ) diff --git a/python/getalp/wsd/data_config.py b/python/getalp/wsd/data_config.py new file mode 100644 index 0000000..9c29229 --- /dev/null +++ b/python/getalp/wsd/data_config.py @@ -0,0 +1,89 @@ +import json +from getalp.wsd.common import get_vocabulary, get_vocabulary_size, get_pretrained_embeddings +from getalp.common.common import get_value_as_str_list, get_value_as_bool_list, pad_list +from typing import List +import os +import numpy as np + + +class DataConfig(object): + + def __init__(self): + self.config_root_path: str = str() + self.input_features: int = int() + self.input_vocabularies: List[List[str]] = [] + self.input_vocabulary_sizes: List[int] = [] + self.input_embeddings_path: List[str] = [] + self.input_embeddings: List[np.array] = [] + self.input_clear_text: List[bool] = [] + self.output_features: int = int() + self.output_feature_names: List[str] = [] + self.output_vocabulary_sizes: List[int] = [] + self.output_translations: int = 0 + self.output_translation_features: int = 0 + self.output_translation_vocabularies: List[List[List[str]]] = [] + self.output_translation_vocabulary_sizes: List[List[int]] = [] + self.output_translation_clear_text: bool = bool() + + def load_from_file(self, file_path): + file = open(file_path, "r") + data = json.load(file) + file.close() + self.load_from_serializable_data(data, os.path.dirname(os.path.abspath(file_path))) + + def load_from_serializable_data(self, data, config_root_path): + self.config_root_path = config_root_path + self.input_features = data["input_features"] + self.load_input_vocabularies() + self.load_input_embeddings_path(data) + self.load_input_embeddings() + self.load_input_clear_text_values(data) + self.output_features = data["output_features"] + self.output_feature_names = data["output_annotation_name"] + self.load_output_vocabulary() + self.output_translations = data.get("output_translations", 0) + self.output_translation_features = data.get("output_translation_features", 1) + self.load_translation_output_vocabulary() + self.output_translation_clear_text = data.get("output_translation_clear_text", False) + + def load_input_vocabularies(self): + for i in range(0, self.input_features): + vocab = get_vocabulary(self.config_root_path + "/input_vocabulary" + str(i)) + self.input_vocabularies.append(vocab) + self.input_vocabulary_sizes.append(len(vocab)) + + def load_input_embeddings_path(self, data): + self.input_embeddings_path = get_value_as_str_list(data.get("input_embeddings_path", None)) + self.input_embeddings_path = [get_real_path(path, self.config_root_path) for 
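`DataConfig.load_from_serializable_data` above drives everything from the dictionary stored in `config.json`, plus the `input_vocabularyX`/`output_vocabularyX` files sitting next to it. A hedged sketch of such a dictionary (values are illustrative only; the keys are the ones read by the code above):

```python
# Illustrative config dictionary for load_from_serializable_data
# (values made up; vocabulary files must exist under the root path).
data = {
    "input_features": 1,
    "input_embeddings_path": None,    # or one path per input feature
    "input_clear_text": [True],       # pass raw tokens (e.g. for BERT)
    "output_features": 1,
    "output_annotation_name": ["wsd"],
    "output_translations": 0,         # no translation decoder
}
# config = DataConfig(); config.load_from_serializable_data(data, "$DATADIR")
```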
path in self.input_embeddings_path] + pad_list(self.input_embeddings_path, self.input_features, None) + + def load_input_embeddings(self): + self.input_embeddings = [None if path is None else get_pretrained_embeddings(path) for path in self.input_embeddings_path] + + def load_input_clear_text_values(self, data): + self.input_clear_text = get_value_as_bool_list(data.get("input_clear_text", None)) + pad_list(self.input_clear_text, self.input_features, False) + + def load_output_vocabulary(self): + for i in range(0, self.output_features): + self.output_vocabulary_sizes.append(get_vocabulary_size(self.config_root_path + "/output_vocabulary" + str(i))) + + def load_translation_output_vocabulary(self): + for i in range(self.output_translations): + vocabs: List[List[str]] = [] + vocab_sizes: List[int] = [] + for j in range(self.output_translation_features): + vocab = get_vocabulary(self.config_root_path + "/output_translation" + str(i) + "_vocabulary" + str(j)) + vocabs.append(vocab) + vocab_sizes.append(len(vocab)) + self.output_translation_vocabularies.append(vocabs) + self.output_translation_vocabulary_sizes.append(vocab_sizes) + + +def get_real_path(path, root_path): + if path is None: + return None + elif os.path.isabs(path): + return path + else: + return root_path + "/" + path diff --git a/python/getalp/wsd/loss/__init__.py b/python/getalp/wsd/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/getalp/wsd/loss/label_smoothed_cross_entropy.py b/python/getalp/wsd/loss/label_smoothed_cross_entropy.py new file mode 100644 index 0000000..cab90c8 --- /dev/null +++ b/python/getalp/wsd/loss/label_smoothed_cross_entropy.py @@ -0,0 +1,24 @@ +from torch.nn import Module +from torch.nn.functional import log_softmax +from torch import Tensor + + +class LabelSmoothedCrossEntropyCriterion(Module): + + def __init__(self, label_smoothing=0.1, ignore_index=-100): + super().__init__() + self.eps = label_smoothing + self.ignore_index = ignore_index + + # output : N x vocab_size + # target : N [0, vocab_size[ + def forward(self, output: Tensor, target: Tensor): + output = log_softmax(output, dim=1) + non_pad_mask = target.ne(self.ignore_index) + nll_loss = -output.gather(dim=1, index=target.unsqueeze(1))[non_pad_mask] + smooth_loss = -output.sum(dim=1, keepdim=True)[non_pad_mask] + nll_loss = nll_loss.mean() + smooth_loss = smooth_loss.mean() + eps_i = self.eps / output.size(1) + loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss + return loss diff --git a/python/getalp/wsd/model.py b/python/getalp/wsd/model.py index 10b5193..b6e4522 100644 --- a/python/getalp/wsd/model.py +++ b/python/getalp/wsd/model.py @@ -1,77 +1,154 @@ -from torch.nn import Module, Embedding, LSTM, Dropout, Linear, CrossEntropyLoss -from torch.nn.functional import softmax +from torch.nn import Module, ModuleList, CrossEntropyLoss +from torch.nn.functional import log_softmax from torch.optim import Adam import torch -import numpy as np -from getalp.wsd.modules.attention import Attention +from getalp.wsd.loss.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +from getalp.wsd.modules.embeddings import get_elmo_embeddings, get_bert_embeddings, EmbeddingsLUT +from getalp.wsd.modules.encoders import EncoderBase, EncoderLSTM, EncoderTransformer +from getalp.wsd.modules.decoders import DecoderClassify, DecoderTranslateTransformer +from getalp.wsd.optim import SchedulerFixed, SchedulerNoam from getalp.wsd.model_config import ModelConfig +from getalp.wsd.data_config import DataConfig +from getalp.wsd.torch_utils import default_device import random - -cpu_device = torch.device("cpu") - -if torch.cuda.is_available(): - gpu_device = torch.device("cuda:0") - default_device = gpu_device -else: - default_device = cpu_device +from typing import List, Union class Model(object): - def __init__(self): - self.config: ModelConfig = ModelConfig() + def __init__(self, config: ModelConfig): + self.config: ModelConfig = config self.backend: TorchModel = None + self.optimizer: TorchModelOptimizer = TorchModelOptimizer() + self.classification_criterion = CrossEntropyLoss(ignore_index=0) + self.translation_criterion = LabelSmoothedCrossEntropyCriterion(label_smoothing=0.1, ignore_index=0) def create_model(self): - self.backend = TorchModel(self.config) + self.backend = TorchModel(self.config, self.config.data_config) + self.optimizer.set_backend(self.backend) + + def get_number_of_parameters(self, filter_requires_grad: bool): + raw_count = sum(p.numel() for p in self.backend.parameters() if not filter_requires_grad or p.requires_grad) + if raw_count > 1000000: + str_count = "%.2f" % (float(raw_count) / float(1000000)) + "M" + elif raw_count > 1000: + str_count = "%.2f" % (float(raw_count) / float(1000)) + "K" + else: + str_count = str(raw_count) + return str_count + + def set_adam_parameters(self, adam_beta1: float, adam_beta2: float, adam_eps: float): + self.optimizer.set_adam_parameters(adam_beta1=adam_beta1, adam_beta2=adam_beta2, adam_eps=adam_eps) - def set_learning_rate(self, learning_rate): - self.backend.optimizer = Adam(filter(lambda p: p.requires_grad, self.backend.parameters()), lr=learning_rate) + def set_lr_scheduler(self, lr_scheduler: str, fixed_lr: float, warmup: int, model_size: int): + self.optimizer.set_scheduler(scheduler=lr_scheduler, fixed_lr=fixed_lr, warmup=warmup, model_size=model_size) + + def update_learning_rate(self, step): + self.optimizer.update_learning_rate(step) + + def set_beam_size(self, beam_size: int): + if self.backend.decoder_translation is not None: + self.backend.decoder_translation.beam_size = beam_size def load_model_weights(self, file_path): - self.backend.load_state_dict(torch.load(file_path, map_location=str(default_device)), strict=True) + save = torch.load(file_path, map_location=str(default_device)) + self.config.load_from_serializable_data(save["config"]) + self.create_model() + self.backend.encoder.load_state_dict(save["backend_encoder"], strict=True) + if 
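The translation criterion instantiated above mixes the usual negative log-likelihood with a uniform distribution over the vocabulary: with `eps = 0.1` the loss is `0.9 * nll + (0.1 / V) * smooth`, both terms averaged over non-padding targets. A quick usage sketch of the module defined earlier in this diff:

```python
# Usage sketch for LabelSmoothedCrossEntropyCriterion (defined above).
import torch
from getalp.wsd.loss.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion

criterion = LabelSmoothedCrossEntropyCriterion(label_smoothing=0.1, ignore_index=0)
logits = torch.randn(8, 50, requires_grad=True)  # N x vocab_size
target = torch.randint(1, 50, (8,))              # N, index 0 = padding
loss = criterion(logits, target)
loss.backward()
```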
self.backend.decoder_classification is not None: + self.backend.decoder_classification.load_state_dict(save["backend_decoder_classification"], strict=True) + if self.backend.decoder_translation is not None: + self.backend.decoder_translation.load_state_dict(save["backend_decoder_translation"], strict=True) + self.optimizer.adam.load_state_dict(save["optimizer"]) + for i in range(len(self.backend.embeddings)): + if not self.backend.embeddings[i].is_fixed(): + self.backend.embeddings[i].load_state_dict(save["backend_embeddings" + str(i)], strict=True) def save_model_weights(self, file_path): - torch.save(self.backend.state_dict(), file_path) + save = {"config": self.config.get_serializable_data(), + "backend_encoder": self.backend.encoder.state_dict(), + "optimizer": self.optimizer.adam.state_dict()} + if self.backend.decoder_classification is not None: + save["backend_decoder_classification"] = self.backend.decoder_classification.state_dict() + if self.backend.decoder_translation is not None: + save["backend_decoder_translation"] = self.backend.decoder_translation.state_dict() + for i in range(len(self.backend.embeddings)): + if not self.backend.embeddings[i].is_fixed(): + save["backend_embeddings" + str(i)] = self.backend.embeddings[i].state_dict() + torch.save(save, file_path) def begin_train_on_batch(self): - self.backend.optimizer.zero_grad() + self.optimizer.adam.zero_grad() - def train_on_batch(self, batch_x, batch_y, batch_z): + def train_on_batch(self, batch_x, batch_y, batch_tt): self.backend.train() - self.zero_random_tokens(batch_x, self.config.word_dropout_rate) - losses, total_loss = self.forward_and_compute_loss(batch_x, batch_y, batch_z) + losses, total_loss = self.forward_and_compute_loss(batch_x, batch_y, batch_tt) total_loss.backward() return losses def end_train_on_batch(self): - self.backend.optimizer.step() + self.optimizer.adam.step() + + @torch.no_grad() + def predict_wsd_on_batch(self, batch_x): + self.backend.eval() + batch_x = self.convert_batch_x_on_default_device(batch_x) + outputs_classification, _ = self.backend(batch_x, [[None]]) + outputs = outputs_classification[0] + return outputs - def predict_model_on_batch(self, batch_x): + @torch.no_grad() + def predict_all_features_on_batch(self, batch_x): self.backend.eval() - batch_x = self.convert_batch_on_default_device(batch_x) - output = self.backend(batch_x) - output = softmax(output[0], dim=2) - return output.detach().cpu().numpy() + batch_x = self.convert_batch_x_on_default_device(batch_x) + outputs_classification, _ = self.backend(batch_x, [[None]]) + outputs = outputs_classification + return outputs - def predict_model_on_sample(self, sample_x): - return self.predict_model_on_batch([np.expand_dims(x, axis=0) for x in sample_x])[0] + @torch.no_grad() + def predict_translation_on_batch(self, batch_x): + self.backend.eval() + batch_x = self.convert_batch_x_on_default_device(batch_x) + _, outputs_translation = self.backend(batch_x, [[None]]) + outputs = outputs_translation[0][0] + return outputs - def test_model_on_batch(self, batch_x, batch_y, batch_z): + @torch.no_grad() + def predict_wsd_and_translation_on_batch(self, batch_x): self.backend.eval() - losses, total_loss = self.forward_and_compute_loss(batch_x, batch_y, batch_z) + batch_x = self.convert_batch_x_on_default_device(batch_x) + outputs_classification, outputs_translation = self.backend(batch_x, [[None]]) + outputs_classification = outputs_classification[0] + outputs_translation = outputs_translation[0][0] + return outputs_classification, 
outputs_translation + + @torch.no_grad() + def test_model_on_batch(self, batch_x, batch_y, batch_tt): + self.backend.eval() + losses, total_loss = self.forward_and_compute_loss(batch_x, batch_y, batch_tt) return losses - def forward_and_compute_loss(self, batch_x, batch_y, batch_z): - batch_x = self.convert_batch_on_default_device(batch_x) - outputs = self.backend(batch_x) + def forward_and_compute_loss(self, batch_x, batch_y, batch_tt): + batch_x = self.convert_batch_x_on_default_device(batch_x) + batch_y = self.convert_batch_y_on_default_device(batch_y) + batch_tt = self.convert_batch_tt_on_default_device(batch_tt) + outputs_classification, outputs_translation = self.backend(batch_x, batch_tt) losses = [] total_loss = None for i in range(len(batch_y)): - batch_y[i] = torch.from_numpy(batch_y[i]).to(default_device) - feature_outputs = outputs[i].view(-1, outputs[i].shape[2]) + feature_outputs = outputs_classification[i].view(-1, outputs_classification[i].shape[2]) feature_batch_y = batch_y[i].view(-1) - loss = self.backend.criterion(feature_outputs, feature_batch_y) + loss = self.classification_criterion(feature_outputs, feature_batch_y) + losses.append(loss.item()) + if total_loss is None: + total_loss = loss + else: + total_loss = total_loss + loss + if len(batch_tt) > 0: + outputs_translation[0][0] = outputs_translation[0][0].contiguous() + translation_output = outputs_translation[0][0].view(-1, outputs_translation[0][0].shape[2]) + translation_batch_tt = batch_tt[0][0].view(-1) + loss = self.translation_criterion(translation_output, translation_batch_tt) losses.append(loss.item()) if total_loss is None: total_loss = loss @@ -79,105 +156,122 @@ def forward_and_compute_loss(self, batch_x, batch_y, batch_z): total_loss = total_loss + loss return losses, total_loss + def convert_batch_x_on_default_device(self, batch_x): + return [x.to(default_device) if not self.config.data_config.input_clear_text[i] else x for i, x in enumerate(batch_x)] + @staticmethod - def convert_batch_on_default_device(batch): - for i in range(len(batch)): - batch[i] = torch.from_numpy(batch[i]).to(default_device) - return batch + def convert_batch_y_on_default_device(batch_y): + return [x.to(default_device) for x in batch_y] + + @staticmethod + def convert_batch_tt_on_default_device(batch_tt): + return [[x.to(default_device) for x in y] for y in batch_tt] @staticmethod def zero_random_tokens(batch, proba): - if proba is None: return + if proba is None: + return for i in range(len(batch[0])): if random.random() < proba: for j in range(len(batch)): batch[j][i] = 0 + # samples : sample x xyztt x feat x batch x seq + def preprocess_samples(self, samples): + for sample in samples: + sample[0][0], new_size, indices = self.backend.embeddings[0].preprocess_sample_first(sample[0][0]) + for i in range(self.config.data_config.input_features): + sample[0][i] = self.backend.embeddings[i].preprocess_sample_next(sample[0][i], new_size, indices) -class TorchModel(Module): - def __init__(self, config): - super(TorchModel, self).__init__() - - resulting_embeddings_size = 0 - self.embeddings = [] - for i in range(0, config.input_features): - resulting_embeddings_size += config.input_embeddings_sizes[i] - if config.input_embeddings[i] is not None: - module = Embedding.from_pretrained(embeddings=torch.from_numpy(config.input_embeddings[i]).to(default_device), freeze=True) - if not config.legacy_model: - self.add_module("input_embedding" + str(i), module) - else: - module = Embedding(num_embeddings=config.input_vocabulary_sizes[i], 
embedding_dim=config.input_embeddings_sizes[i], padding_idx=0) - self.add_module("input_embedding" + str(i), module) - self.embeddings.append(module) +class TorchModelOptimizer(object): - if config.linear_before_lstm: - self.linear_before_lstm = Linear(in_features=resulting_embeddings_size, out_features=resulting_embeddings_size) - self.add_module("linear_before_lstm", self.linear_before_lstm) + def __init__(self): + super().__init__() + self.adam_beta1: float = None + self.adam_beta2: float = None + self.adam_eps: float = None + self.adam: Adam = None + self.scheduler: Union[SchedulerFixed, SchedulerNoam] = None + + def set_adam_parameters(self, adam_beta1: float, adam_beta2: float, adam_eps: float): + self.adam_beta1 = adam_beta1 + self.adam_beta2 = adam_beta2 + self.adam_eps = adam_eps + + def set_scheduler(self, scheduler: str, fixed_lr: float, warmup: int, model_size: int): + if scheduler == "noam": + self.scheduler = SchedulerNoam(warmup=warmup, model_size=model_size) + else: # if scheduler == "fixed": + self.scheduler = SchedulerFixed(fixed_lr=fixed_lr) + + def set_backend(self, backend: Module): + if self.adam_beta1 is None or self.adam_beta2 is None or self.adam_eps is None: + self.adam = Adam(filter(lambda p: p.requires_grad, backend.parameters())) else: - self.linear_before_lstm = None + self.adam = Adam(filter(lambda p: p.requires_grad, backend.parameters()), betas=(self.adam_beta1, self.adam_beta2), eps=self.adam_eps) - if config.dropout_rate_before_lstm is not None: - self.dropout_before_lstm = Dropout(p=config.dropout_rate_before_lstm) - self.add_module("dropout_before_lstm", self.dropout_before_lstm) - else: - self.dropout_before_lstm = None + def update_learning_rate(self, step: int): + self.set_learning_rate(self.scheduler.get_learning_rate(step)) - self.lstm = LSTM(input_size=resulting_embeddings_size, hidden_size=config.lstm_units_size, - num_layers=config.lstm_layers, bidirectional=True, batch_first=True) - self.add_module("lstm", self.lstm) + def set_learning_rate(self, learning_rate: float): + for param_group in self.adam.param_groups: + param_group['lr'] = learning_rate - if config.dropout_rate is not None: - self.dropout = Dropout(p=config.dropout_rate) - self.add_module("dropout", self.dropout) - else: - self.dropout = None - if config.attention_layer is True: - self.attention = Attention(in_features=config.lstm_units_size * 2) - self.add_module("attention", self.attention) - next_layer_in_features = config.lstm_units_size * 4 +class TorchModel(Module): + + def __init__(self, config: ModelConfig, data_config: DataConfig): + super().__init__() + self.config = config + + self.embeddings: List[Module] = [] + for i in range(0, data_config.input_features): + if config.input_elmo_path[i] is not None: + module = get_elmo_embeddings(elmo_path=config.input_elmo_path[i], input_vocabulary=data_config.input_vocabularies[i], clear_text=data_config.input_clear_text[i]) + elif config.input_bert_path[i] is not None: + module = get_bert_embeddings(bert_path=config.input_bert_path[i], clear_text=data_config.input_clear_text[i]) + else: + module = EmbeddingsLUT(input_embeddings=data_config.input_embeddings[i], input_vocabulary_size=data_config.input_vocabulary_sizes[i], input_embeddings_size=config.input_embeddings_sizes[i], clear_text=data_config.input_clear_text[i]) + config.input_embeddings_sizes[i] = module.get_output_dim() + self.add_module("input_embedding" + str(i), module) + self.embeddings.append(module) + + if config.encoder_type == "lstm": + self.encoder = 
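`set_scheduler` above chooses between a constant learning rate and a Noam-style schedule. The classic Noam rule from the Transformer paper is sketched below; this is presumably what `SchedulerNoam` computes from `warmup` and `model_size`, but that module's code is not part of this diff, so treat the formula as an assumption:

```python
# Classic Noam learning-rate rule (assumed behaviour of SchedulerNoam;
# its source is not shown in this diff).
def noam_learning_rate(step: int, model_size: int, warmup: int) -> float:
    step = max(step, 1)
    return model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Rises linearly for `warmup` steps, then decays as 1 / sqrt(step).
```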
EncoderLSTM(config) + elif config.encoder_type == "transformer": + self.encoder = EncoderTransformer(config) else: - self.attention = None - next_layer_in_features = config.lstm_units_size * 2 - - self.output_linears = [] - if config.legacy_model: - module = Linear(in_features=next_layer_in_features, out_features=config.output_vocabulary_sizes[0]) - self.output_linears.append(module) - self.add_module("linear", module) + self.encoder = EncoderBase(config) + + if data_config.output_features > 0: + self.decoder_classification = DecoderClassify(config, data_config) else: - for i in range(0, config.output_features): - module = Linear(in_features=next_layer_in_features, out_features=config.output_vocabulary_sizes[i]) - self.output_linears.append(module) - self.add_module("output_linear" + str(i), module) + self.decoder_classification = None + if data_config.output_translations > 0: + self.decoder_translation = DecoderTranslateTransformer(config, data_config, self.embeddings[0]) + else: + self.decoder_translation = None if torch.cuda.is_available(): self.cuda() - self.criterion = CrossEntropyLoss(ignore_index=0) - - self.optimizer = Adam(filter(lambda p: p.requires_grad, self.parameters())) - - - def forward(self, inputs): - for i in range(0, len(inputs)): - inputs[i] = self.embeddings[i](inputs[i]) - inputs = torch.cat(inputs, dim=2) - if self.linear_before_lstm is not None: - inputs = self.linear_before_lstm(inputs) - if self.dropout_before_lstm is not None: - inputs = self.dropout_before_lstm(inputs) - inputs, _ = self.lstm(inputs) - if self.dropout is not None: - inputs = self.dropout(inputs) - if self.attention is not None: - inputs = self.attention(inputs) - outputs = [] - for i in range(0, len(self.output_linears)): - outputs.append(self.output_linears[i](inputs)) - return outputs - - + # inputs: + # - List[Union[LongTensor, List[str]]] features x batch x seq_in (input features) + # - List[List[LongTensor]] translations x features x batch x seq_out (output translations) (training only) + # outputs: + # - List[FloatTensor] features x batch x seq_in x vocab_out (output features) + # - List[List[FloatTensor]] translations x features x batch x seq_out x vocab_out (output translations) + def forward(self, inputs, translation_true_output): + inputs[0], pad_mask, token_indices = self.embeddings[0](inputs[0]) + for i in range(1, len(inputs)): + inputs[i], _, _ = self.embeddings[i](inputs[i]) + inputs = self.encoder(inputs, pad_mask) + classification_outputs = [] + if self.decoder_classification is not None: + classification_outputs = self.decoder_classification(inputs, token_indices) + translation_outputs = [] + if self.decoder_translation is not None: + translation_outputs = self.decoder_translation(inputs, pad_mask, translation_true_output[0][0]) + return classification_outputs, [[translation_outputs]] diff --git a/python/getalp/wsd/model_config.py b/python/getalp/wsd/model_config.py index 3ce65bd..dbc705f 100644 --- a/python/getalp/wsd/model_config.py +++ b/python/getalp/wsd/model_config.py @@ -1,83 +1,134 @@ import json -from getalp.wsd.common import get_pretrained_embeddings, get_vocabulary_size +from getalp.common.common import get_value_as_int_list, pad_list, get_value_as_str_list +from getalp.wsd.data_config import DataConfig from typing import List -import numpy as np -import os + class ModelConfig(object): - def __init__(self): - self.config_root_path: str = str() - self.input_features: int = int() - self.input_vocabulary_sizes: List[int] = [] - self.input_embeddings: List[np.array] = 
[] + def __init__(self, data_config: DataConfig): + self.data_config: DataConfig = data_config self.input_embeddings_sizes: List[int] = [] - self.output_features: int = int() - self.output_vocabulary_sizes: List[int] = [] - self.lstm_units_size: int = int() - self.lstm_layers: int = int() - self.linear_before_lstm: bool = bool() - self.dropout_rate_before_lstm: float = float() - self.dropout_rate: float = float() - self.word_dropout_rate: float = float() - self.attention_layer: bool = bool() - self.legacy_model: bool = bool() + self.input_elmo_path: List[str] = [] + self.input_bert_path: List[str] = [] + self.input_flair_path: List[str] = [] + self.input_word_dropout_rate: float = float() + self.input_resize: List[int] = [] + self.input_apply_linear: bool = bool() + self.input_linear_size: int = int() + self.input_dropout_rate: float = float() + self.encoder_type: str = str() + self.encoder_lstm_hidden_size: int = int() + self.encoder_lstm_layers: int = int() + self.encoder_lstm_dropout: float = float() + self.encoder_transformer_hidden_size: int = int() + self.encoder_transformer_layers: int = int() + self.encoder_transformer_heads: int = int() + self.encoder_transformer_dropout: float = float() + self.encoder_transformer_positional_encoding: bool = bool() + self.encoder_transformer_scale_embeddings: bool = bool() + self.encoder_output_size: int = int() + self.decoder_translation_transformer_hidden_size = int() + self.decoder_translation_transformer_layers = int() + self.decoder_translation_transformer_heads = int() + self.decoder_translation_transformer_dropout = float() + self.decoder_translation_scale_embeddings: bool = bool() + self.decoder_translation_share_embeddings = bool() + self.decoder_translation_share_encoder_embeddings = bool() + self.decoder_translation_tokenizer_bert = str() def load_from_file(self, file_path): file = open(file_path, "r") data = json.load(file) file.close() - self.load_from_serializable_data(data, os.path.dirname(os.path.abspath(file_path))) + self.load_from_serializable_data(data) + + def load_from_serializable_data(self, data): + self.set_input_elmo_path(data.get("input_elmo_path", None)) + self.set_input_bert_model(data.get("input_bert_path", None)) + self.set_input_flair_model(data.get("input_flair_path", None)) + self.load_input_embeddings_sizes(data) + self.input_word_dropout_rate = data.get("input_word_dropout_rate", None) + self.input_resize = data.get("input_resize", [None for _ in range(self.data_config.input_features)]) + self.input_apply_linear = data.get("input_apply_linear", False) + self.input_linear_size = data.get("input_linear_size", None) + self.input_dropout_rate = data.get("input_dropout_rate", None) + self.encoder_type = data.get("encoder_type", "lstm") + self.encoder_lstm_hidden_size = data.get("encoder_lstm_hidden_size", 1000) + self.encoder_lstm_layers = data.get("encoder_lstm_layers", 1) + self.encoder_lstm_dropout = data.get("encoder_lstm_dropout", 0.5) + self.encoder_transformer_hidden_size = data.get("encoder_transformer_hidden_size", 512) + self.encoder_transformer_layers = data.get("encoder_transformer_layers", 6) + self.encoder_transformer_heads = data.get("encoder_transformer_heads", 8) + self.encoder_transformer_dropout = data.get("encoder_transformer_dropout", 0.1) + self.encoder_transformer_positional_encoding = data.get("encoder_transformer_positional_encoding", True) + self.encoder_transformer_scale_embeddings = data.get("encoder_transformer_scale_embeddings", True) + self.decoder_translation_transformer_hidden_size = 
data.get("decoder_translation_transformer_hidden_size", 512) + self.decoder_translation_transformer_layers = data.get("decoder_translation_transformer_layers", 6) + self.decoder_translation_transformer_heads = data.get("decoder_translation_transformer_heads", 8) + self.decoder_translation_transformer_dropout = data.get("decoder_translation_transformer_dropout", 0.1) + self.decoder_translation_scale_embeddings = data.get("decoder_translation_scale_embeddings", True) + self.decoder_translation_share_embeddings = data.get("decoder_translation_share_embeddings", False) + self.decoder_translation_share_encoder_embeddings = data.get("decoder_translation_share_encoder_embeddings", False) + self.decoder_translation_tokenizer_bert = data.get("decoder_translation_tokenizer_bert", None) - def load_from_serializable_data(self, data, config_root_path): - self.config_root_path = config_root_path - self.input_features = data["input_features"] - self.load_input_vocabularies() - self.load_input_embeddings(data) - self.output_features = data["output_features"] - self.load_output_vocabulary() - self.lstm_units_size = data["lstm_units_size"] - self.lstm_layers = data["lstm_layers"] - self.linear_before_lstm = data["linear_before_lstm"] - self.dropout_rate_before_lstm = data["dropout_rate_before_lstm"] - self.dropout_rate = data["dropout_rate"] - self.word_dropout_rate = data["word_dropout_rate"] - self.attention_layer = data["attention_layer"] - self.legacy_model = data["legacy_model"] + def load_input_embeddings_sizes(self, data): + self.input_embeddings_sizes = get_value_as_int_list(data.get("input_embeddings_size", None)) + pad_list(self.input_embeddings_sizes, self.data_config.input_features, 300) + self.reset_input_embeddings_sizes() - def load_input_vocabularies(self): - for i in range(0, self.input_features): - self.input_vocabulary_sizes.append(get_vocabulary_size(self.config_root_path + "/input_vocabulary" + str(i))) + def set_input_elmo_path(self, elmo_path): + self.input_elmo_path = get_value_as_str_list(elmo_path) + pad_list(self.input_elmo_path, self.data_config.input_features, None) + self.reset_input_embeddings_sizes() - def load_input_embeddings(self, data): - input_embeddings_paths = data["input_embeddings_path"] - if input_embeddings_paths is None: - input_embeddings_paths = [] - elif isinstance(input_embeddings_paths, str): - input_embeddings_paths = [input_embeddings_paths] - self.input_embeddings = [] - for input_embeddings_path in input_embeddings_paths: - if input_embeddings_path is None: - self.input_embeddings.append(None) - elif os.path.isabs(input_embeddings_path): - self.input_embeddings.append(get_pretrained_embeddings(input_embeddings_path)) - else: - self.input_embeddings.append(get_pretrained_embeddings(self.config_root_path + "/" + input_embeddings_path)) - for i in range(len(self.input_embeddings), len(self.input_vocabulary_sizes)): - self.input_embeddings.append(None) + def set_input_bert_model(self, bert_model): + self.input_bert_path = get_value_as_str_list(bert_model) + pad_list(self.input_bert_path, self.data_config.input_features, None) + self.reset_input_embeddings_sizes() - self.input_embeddings_sizes = data["input_embeddings_size"] - if self.input_embeddings_sizes is None: - self.input_embeddings_sizes = [] - elif isinstance(self.input_embeddings_sizes, int): - self.input_embeddings_sizes = [self.input_embeddings_sizes] - for i in range(len(self.input_embeddings_sizes), len(self.input_embeddings)): - self.input_embeddings_sizes.append(None) - for i in range(0, 
len(self.input_embeddings_sizes)): - if self.input_embeddings[i] is not None: - self.input_embeddings_sizes[i] = self.input_embeddings[i].shape[1] + def set_input_flair_model(self, flair_model): + self.input_flair_path = get_value_as_str_list(flair_model) + pad_list(self.input_flair_path, self.data_config.input_features, None) + self.reset_input_embeddings_sizes() - def load_output_vocabulary(self): - for i in range(0, self.output_features): - self.output_vocabulary_sizes.append(get_vocabulary_size(self.config_root_path + "/output_vocabulary" + str(i))) + def reset_input_embeddings_sizes(self): + for i in range(len(self.input_embeddings_sizes)): + if self.data_config.input_embeddings[i] is not None: + self.input_embeddings_sizes[i] = self.data_config.input_embeddings[i].shape[1] + if self.input_elmo_path[i] is not None \ + or self.input_bert_path[i] is not None \ + or self.input_flair_path[i] is not None: + self.input_embeddings_sizes[i] = None + def get_serializable_data(self): + data = { + "input_embeddings_size": self.input_embeddings_sizes, + "input_elmo_path": self.input_elmo_path, + "input_bert_path": self.input_bert_path, + "input_flair_path": self.input_flair_path, + "input_word_dropout_rate": self.input_word_dropout_rate, + "input_resize": self.input_resize, + "input_apply_linear": self.input_apply_linear, + "input_linear_size": self.input_linear_size, + "input_dropout_rate": self.input_dropout_rate, + "encoder_type": self.encoder_type, + "encoder_lstm_hidden_size": self.encoder_lstm_hidden_size, + "encoder_lstm_layers": self.encoder_lstm_layers, + "encoder_lstm_dropout": self.encoder_lstm_dropout, + "encoder_transformer_hidden_size": self.encoder_transformer_hidden_size, + "encoder_transformer_layers": self.encoder_transformer_layers, + "encoder_transformer_heads": self.encoder_transformer_heads, + "encoder_transformer_dropout": self.encoder_transformer_dropout, + "encoder_transformer_positional_encoding": self.encoder_transformer_positional_encoding, + "encoder_transformer_scale_embeddings": self.encoder_transformer_scale_embeddings, + "decoder_translation_transformer_hidden_size": self.decoder_translation_transformer_hidden_size, + "decoder_translation_transformer_layers": self.decoder_translation_transformer_layers, + "decoder_translation_transformer_heads": self.decoder_translation_transformer_heads, + "decoder_translation_transformer_dropout": self.decoder_translation_transformer_dropout, + "decoder_translation_scale_embeddings": self.decoder_translation_scale_embeddings, + "decoder_translation_share_embeddings": self.decoder_translation_share_embeddings, + "decoder_translation_share_encoder_embeddings": self.decoder_translation_share_encoder_embeddings, + "decoder_translation_tokenizer_bert": self.decoder_translation_tokenizer_bert + } + return data diff --git a/python/getalp/wsd/modules/__init__.py b/python/getalp/wsd/modules/__init__.py index e69de29..db4a430 100644 --- a/python/getalp/wsd/modules/__init__.py +++ b/python/getalp/wsd/modules/__init__.py @@ -0,0 +1 @@ +from .positional_encoding import PositionalEncoding diff --git a/python/getalp/wsd/modules/decoders/__init__.py b/python/getalp/wsd/modules/decoders/__init__.py new file mode 100644 index 0000000..d401db9 --- /dev/null +++ b/python/getalp/wsd/modules/decoders/__init__.py @@ -0,0 +1,2 @@ +from .decoder_classify import DecoderClassify +from .decoder_translate_transformer import DecoderTranslateTransformer diff --git a/python/getalp/wsd/modules/decoders/decoder_classify.py 
b/python/getalp/wsd/modules/decoders/decoder_classify.py new file mode 100644 index 0000000..2ce0d52 --- /dev/null +++ b/python/getalp/wsd/modules/decoders/decoder_classify.py @@ -0,0 +1,37 @@ +from torch.nn import Module, Linear +from getalp.wsd.model_config import ModelConfig +from getalp.wsd.data_config import DataConfig +from getalp.wsd.torch_fix import * +from getalp.wsd.torch_utils import default_device + + +class DecoderClassify(Module): + + def __init__(self, config: ModelConfig, data_config: DataConfig): + super().__init__() + + self.output_linears = [] + for i in range(0, data_config.output_features): + module = Linear(in_features=config.encoder_output_size, out_features=data_config.output_vocabulary_sizes[i]) + self.output_linears.append(module) + self.add_module("output_linear" + str(i), module) + + # input: + # - inputs: FloatTensor - batch x seq_in x hidden + # - token_indices: List[List[int]] - batch x real_seq_in + # output: + # - output: List[FloatTensor] - batch x real_seq_in x out_vocabulary_dim + def forward(self, inputs, token_indices): + if token_indices is not None: + max_length = max([len(seq) for seq in token_indices]) + new_inputs = torch_zeros(inputs.size(0), max_length, inputs.size(2), dtype=torch_float32, device=default_device) + for i in range(len(token_indices)): + for j in range(len(token_indices[i])): + new_inputs[i][j] = inputs[i][token_indices[i][j]] + inputs = new_inputs + outputs = [] + for i in range(0, len(self.output_linears)): + outputs.append(self.output_linears[i](inputs)) + return outputs + + diff --git a/python/getalp/wsd/modules/decoders/decoder_translate_transformer.py b/python/getalp/wsd/modules/decoders/decoder_translate_transformer.py new file mode 100644 index 0000000..d82e5b8 --- /dev/null +++ b/python/getalp/wsd/modules/decoders/decoder_translate_transformer.py @@ -0,0 +1,14 @@ +from torch.nn import Module +import torch +from getalp.wsd.model_config import ModelConfig +from getalp.wsd.data_config import DataConfig + + +class DecoderTranslateTransformer(Module): + + def __init__(self, config: ModelConfig, data_config: DataConfig, encoder_embeddings): + super().__init__() + raise NotImplementedError + + def forward(self, encoder_output: torch.Tensor, pad_mask: torch.Tensor, true_output: torch.Tensor): + raise NotImplementedError diff --git a/python/getalp/wsd/modules/embeddings/__init__.py b/python/getalp/wsd/modules/embeddings/__init__.py new file mode 100644 index 0000000..4677e9e --- /dev/null +++ b/python/getalp/wsd/modules/embeddings/__init__.py @@ -0,0 +1,3 @@ +from .embeddings_lut import EmbeddingsLUT +from .embeddings_bert import EmbeddingsBert, get_bert_embeddings +from .embeddings_elmo import EmbeddingsElmo, get_elmo_embeddings diff --git a/python/getalp/wsd/modules/embeddings/embeddings_bert.py b/python/getalp/wsd/modules/embeddings/embeddings_bert.py new file mode 100644 index 0000000..aad2310 --- /dev/null +++ b/python/getalp/wsd/modules/embeddings/embeddings_bert.py @@ -0,0 +1,87 @@ +from torch.nn import Module, Linear +from getalp.wsd.torch_fix import * +from torch.nn.utils.rnn import pad_sequence +from getalp.wsd.torch_utils import default_device +from typing import List, Union, Dict + + +class EmbeddingsBert(Module): + + def __init__(self, bert_path: str): + super().__init__() + from pytorch_pretrained_bert import BertModel, BertTokenizer + self.bert_embeddings = BertModel.from_pretrained(bert_path) + self.bert_tokenizer = BertTokenizer.from_pretrained(bert_path, do_lower_case=False) + for param in 
self.bert_embeddings.parameters(): + param.requires_grad = False + self._is_fixed = True + self._output_dim = self.bert_embeddings.config.hidden_size + + # input: + # - sample_x: List[str] - seq_in + # output: + # - sample_x: LongTensor - seq_out + # - new_size: int - seq_out + # - indices: List[int] - seq_in + def preprocess_sample_first(self, sample_x): + seq_token_indices: List[int] = [] + seq_tokens: Union[List[str], torch.Tensor] = [] + current_index = 1 # 0 is [CLS] + for token in sample_x: + subtokens = self.bert_tokenizer.tokenize(token) + seq_token_indices.append(current_index) + current_index += len(subtokens) + for subtoken in subtokens: + seq_tokens.append(subtoken) + seq_tokens = ["[CLS]"] + seq_tokens + ["[SEP]"] + seq_tokens = self.bert_tokenizer.convert_tokens_to_ids(seq_tokens) + seq_tokens = torch_tensor(seq_tokens, dtype=torch_long) + return seq_tokens, seq_tokens.size(0), seq_token_indices + + # input: + # - sample_x: LongTensor - seq_in + # - new_size: int - seq_out + # - indices: List[int] - seq_in + # output: + # - sample_x: Tuple[LongTensor, List[int]] - sample_x, indices + @staticmethod + def preprocess_sample_next(sample_x, new_size, indices): + return sample_x, indices + + # inputs: + # - inputs: List[List[str]] (batch x seq_in) + # output: + # - output: FloatTensor (batch x seq_out x hidden) + # - pad_mask: LongTensor (batch x seq_out) + # - token_indices: List[List[int]] (batch x seq_in) + def forward(self, inputs): + tokens: List[torch.Tensor] = [] + token_indices: List[List[int]] = [] + for seq in inputs: + tokens.append(seq[0].to(default_device)) + token_indices.append(seq[1]) + inputs = tokens + pad_mask = [torch_ones_like(x) for x in inputs] + pad_mask = pad_sequence(pad_mask, batch_first=True, padding_value=0) + inputs = pad_sequence(inputs, batch_first=True, padding_value=0) + inputs, _ = self.bert_embeddings(inputs, attention_mask=pad_mask, output_all_encoded_layers=False) + return inputs, pad_mask, token_indices + + def get_output_dim(self): + return self._output_dim + + def is_fixed(self): + return self._is_fixed + + def get_lut_embeddings(self): + return self.bert_embeddings.embeddings.word_embeddings + + +_bert_embeddings_wrapper: Dict[str, EmbeddingsBert] = {} + + +def get_bert_embeddings(bert_path: str, clear_text: bool): + assert clear_text + if bert_path not in _bert_embeddings_wrapper: + _bert_embeddings_wrapper[bert_path] = EmbeddingsBert(bert_path) + return _bert_embeddings_wrapper[bert_path] diff --git a/python/getalp/wsd/modules/embeddings/embeddings_elmo.py b/python/getalp/wsd/modules/embeddings/embeddings_elmo.py new file mode 100644 index 0000000..9175aa9 --- /dev/null +++ b/python/getalp/wsd/modules/embeddings/embeddings_elmo.py @@ -0,0 +1,79 @@ +from torch.nn import Module +from typing import List, Dict, Tuple +from getalp.wsd.torch_utils import default_device + + +class EmbeddingsElmo(Module): + + # elmo_path = {small, medium, original} + def __init__(self, elmo_path: str, input_vocabulary: List[str], clear_text: bool): + super().__init__() + from allennlp.modules.elmo import Elmo + if elmo_path in _elmo_models_map: + options_file_path, weights_file_path = _elmo_models_map[elmo_path] + else: + options_file_path, weights_file_path = elmo_path + "_options.json", elmo_path + "_weights.hdf5" + self.elmo_embeddings = Elmo(options_file=options_file_path, weight_file=weights_file_path, num_output_representations=1, vocab_to_cache=input_vocabulary) + self.clear_text = clear_text + + # input: + # - sample_x: Union[List[str], LongTensor] - 
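`preprocess_sample_first` above records, for every original word, the position of its first WordPiece (offset by one for `[CLS]`), which is what later lets `DecoderClassify` pick one vector per annotated word. A toy trace of the alignment it builds (the subtokenization is made up):

```python
# Toy trace of the word -> first-subtoken alignment built above
# (made-up subtokenization; the real one comes from BertTokenizer).
words = ["the", "embeddings", "work"]
pieces_of = {"the": ["the"], "embeddings": ["em", "##bed", "##dings"], "work": ["work"]}

indices, current, pieces = [], 1, ["[CLS]"]  # position 0 is [CLS]
for w in words:
    indices.append(current)
    pieces += pieces_of[w]
    current += len(pieces_of[w])
pieces.append("[SEP]")

assert indices == [1, 2, 5]
assert pieces == ["[CLS]", "the", "em", "##bed", "##dings", "work", "[SEP]"]
```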
seq_in + # output: + # - sample_x: Union[List[str], LongTensor] - seq_out + # - new_size: int - seq_out + # - indices: List[int] - seq_in + @staticmethod + def preprocess_sample_first(sample_x): + return sample_x, None, None + + # input: + # - sample_x: Union[List[str], LongTensor] - seq_in + # - new_size: int - seq_out + # - indices: List[int] - seq_in + # output: + # - sample_x: Union[List[str], LongTensor] - seq_out + @staticmethod + def preprocess_sample_next(sample_x, new_size, indices): + return sample_x + + # inputs: + # - inputs: Union[List[List[str]], LongTensor] (batch x seq_in) + # output: + # - output: FloatTensor (batch x seq_out x hidden) + # - pad_mask: LongTensor (batch x seq_out) + # - token_indices: List[List[int]] (batch x seq_in) + def forward(self, inputs): + if self.clear_text: + from allennlp.modules.elmo import batch_to_ids + inputs = batch_to_ids(inputs) + inputs = inputs.to(default_device) + return self.elmo_embeddings(inputs)["elmo_representations"][0], None, None + else: + return self.elmo_embeddings(inputs, inputs)["elmo_representations"][0], inputs, None + + def get_output_dim(self): + return self.elmo_embeddings.get_output_dim() + + @staticmethod + def is_fixed(): + return True + + +_elmo_embeddings_wrapper: Dict[Tuple[str, Tuple[str], bool], EmbeddingsElmo] = {} + + +def get_elmo_embeddings(elmo_path: str, input_vocabulary: List[str], clear_text: bool): + hashable_parameters = (elmo_path, tuple(input_vocabulary), clear_text) + if hashable_parameters not in _elmo_embeddings_wrapper: + _elmo_embeddings_wrapper[hashable_parameters] = EmbeddingsElmo(elmo_path, input_vocabulary, clear_text) + return _elmo_embeddings_wrapper[hashable_parameters] + + +_elmo_models_map: Dict[str, Tuple[str, str]] = { + "small": ("https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json", + "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"), + "medium": ("https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json", + "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"), + "original": ("https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json", + "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5") } diff --git a/python/getalp/wsd/modules/embeddings/embeddings_lut.py b/python/getalp/wsd/modules/embeddings/embeddings_lut.py new file mode 100644 index 0000000..4c87358 --- /dev/null +++ b/python/getalp/wsd/modules/embeddings/embeddings_lut.py @@ -0,0 +1,62 @@ +from torch.nn import Module, Embedding +from getalp.wsd.common import pad_token_index, unk_token_index +from getalp.wsd.torch_fix import * +from getalp.wsd.torch_utils import default_device + + +class EmbeddingsLUT(Module): + + def __init__(self, input_embeddings, input_vocabulary_size, input_embeddings_size, clear_text): + super().__init__() + assert not clear_text + if input_embeddings is not None: + self.lut_embeddings = Embedding.from_pretrained(embeddings=input_embeddings, freeze=True) + self._is_fixed = True + else: + self.lut_embeddings = Embedding(num_embeddings=input_vocabulary_size, embedding_dim=input_embeddings_size, padding_idx=pad_token_index) + self._is_fixed = 
False + self._output_dim = input_embeddings_size + + # input: + # - sample_x: LongTensor - seq_in + # output: + # - sample_x: LongTensor - seq_out + # - new_size: int - seq_out + # - indices: List[int] - seq_in + @staticmethod + def preprocess_sample_first(sample_x): + return sample_x, None, None + + # input: + # - sample_x: LongTensor - seq_in + # - new_size: int - seq_out + # - indices: List[int] - seq_in + # output: + # - sample_x: LongTensor - seq_out + @staticmethod + def preprocess_sample_next(sample_x, new_size, indices): + if indices is None: + return sample_x + new_sample_x = torch_full([new_size], fill_value=unk_token_index, dtype=torch_long) + for i in range(len(indices)): + new_sample_x[indices[i]] = sample_x[i] + return new_sample_x + + # inputs: + # - inputs: LongTensor (batch x seq_in) + # output: + # - output: FloatTensor (batch x seq_out x hidden) + # - pad_mask: LongTensor (batch x seq_out) + # - token_indices: List[List[int]] (batch x seq_in) + def forward(self, inputs): + embeddings = self.lut_embeddings(inputs) + return embeddings, inputs, None + + def get_output_dim(self): + return self._output_dim + + def is_fixed(self): + return self._is_fixed + + def get_lut_embeddings(self): + return self.lut_embeddings diff --git a/python/getalp/wsd/modules/encoders/__init__.py b/python/getalp/wsd/modules/encoders/__init__.py new file mode 100644 index 0000000..5f97cf7 --- /dev/null +++ b/python/getalp/wsd/modules/encoders/__init__.py @@ -0,0 +1,3 @@ +from .encoder_base import EncoderBase +from .encoder_lstm import EncoderLSTM +from .encoder_transformer import EncoderTransformer diff --git a/python/getalp/wsd/modules/encoders/encoder_base.py b/python/getalp/wsd/modules/encoders/encoder_base.py new file mode 100644 index 0000000..97c87ed --- /dev/null +++ b/python/getalp/wsd/modules/encoders/encoder_base.py @@ -0,0 +1,59 @@ +from torch.nn import Module, Dropout, Linear, ModuleList +from getalp.wsd.torch_fix import torch_cat +from getalp.wsd.model_config import ModelConfig + + +class EncoderBase(Module): + + def __init__(self, config: ModelConfig): + super().__init__() + + self.input_resize = None + for i in range(config.data_config.input_features): + if config.input_resize[i] is not None and self.input_resize is None: + self.input_resize = ModuleList() + + if self.input_resize is not None: + self.resulting_embeddings_size = 0 + for i in range(config.data_config.input_features): + if config.input_resize[i] is not None: + self.input_resize.append(Linear(in_features=config.input_embeddings_sizes[i], out_features=config.input_resize[i])) + self.resulting_embeddings_size += config.input_resize[i] + else: + self.input_resize.append(None) + self.resulting_embeddings_size += config.input_embeddings_sizes[i] + else: + self.resulting_embeddings_size = sum(config.input_embeddings_sizes) + + if config.input_apply_linear: + if config.input_linear_size is None: + self.input_linear = Linear(in_features=self.resulting_embeddings_size, out_features=self.resulting_embeddings_size) + else: + self.input_linear = Linear(in_features=self.resulting_embeddings_size, out_features=config.input_linear_size) + self.resulting_embeddings_size = config.input_linear_size + else: + self.input_linear = None + + if config.input_dropout_rate is not None: + self.input_dropout = Dropout(p=config.input_dropout_rate) + else: + self.input_dropout = None + + config.encoder_output_size = self.resulting_embeddings_size + + # input: + # - embeddings: List[FloatTensor] - features x batch x seq x hidden + # - pad_mask: 
LongTensor - batch x seq + # output: + # - output FloatTensor - batch x seq x hidden + def forward(self, embeddings, pad_mask): + if self.input_resize is not None: + for i in range(len(embeddings)): + if self.input_resize[i] is not None: + embeddings[i] = self.input_resize[i](embeddings[i]) + embeddings = torch_cat(embeddings, dim=2) + if self.input_linear is not None: + embeddings = self.input_linear(embeddings) + if self.input_dropout is not None: + embeddings = self.input_dropout(embeddings) + return embeddings diff --git a/python/getalp/wsd/modules/encoders/encoder_lstm.py b/python/getalp/wsd/modules/encoders/encoder_lstm.py new file mode 100644 index 0000000..a4e89e8 --- /dev/null +++ b/python/getalp/wsd/modules/encoders/encoder_lstm.py @@ -0,0 +1,37 @@ +from torch.nn import Module, LSTM, Dropout +from getalp.wsd.model_config import ModelConfig +from getalp.wsd.modules.encoders.encoder_base import EncoderBase + + +class EncoderLSTM(Module): + + def __init__(self, config: ModelConfig): + super().__init__() + + self.base = EncoderBase(config) + + if config.encoder_lstm_layers > 0: + self.lstm = LSTM(input_size=self.base.resulting_embeddings_size, hidden_size=config.encoder_lstm_hidden_size, + num_layers=config.encoder_lstm_layers, bidirectional=True, batch_first=True) + config.encoder_output_size = config.encoder_lstm_hidden_size * 2 + else: + self.lstm = None + config.encoder_output_size = self.base.resulting_embeddings_size + + if config.encoder_lstm_dropout is not None: + self.dropout = Dropout(p=config.encoder_lstm_dropout) + else: + self.dropout = None + + # input: + # - embeddings List[FloatTensor] - features x batch x seq x hidden + # - pad_mask LongTensor - batch x seq + # output: + # - output FloatTensor - batch x seq x hidden + def forward(self, embeddings, pad_mask): + embeddings = self.base(embeddings, pad_mask) + if self.lstm is not None: + embeddings, (_, _) = self.lstm(embeddings) + if self.dropout is not None: + embeddings = self.dropout(embeddings) + return embeddings diff --git a/python/getalp/wsd/modules/encoders/encoder_transformer.py b/python/getalp/wsd/modules/encoders/encoder_transformer.py new file mode 100644 index 0000000..aa7af20 --- /dev/null +++ b/python/getalp/wsd/modules/encoders/encoder_transformer.py @@ -0,0 +1,51 @@ +from torch.nn import Module, LayerNorm, ModuleList, Dropout +from getalp.wsd.model_config import ModelConfig +from onmt.encoders.transformer import TransformerEncoderLayer +from getalp.wsd.modules.encoders.encoder_base import EncoderBase +from getalp.wsd.common import pad_token_index +from getalp.wsd.modules import PositionalEncoding +import math + + +class EncoderTransformer(Module): + + def __init__(self, config: ModelConfig): + super().__init__() + + self.base = EncoderBase(config) + + if config.encoder_transformer_positional_encoding: + self.positional_encoding = PositionalEncoding(self.base.resulting_embeddings_size) + # self.add_module("pe", self.positional_encoding) + else: + self.positional_encoding = None + + if config.encoder_transformer_scale_embeddings: + self.embeddings_scale = math.sqrt(float(self.base.resulting_embeddings_size)) + else: + self.embeddings_scale = None + + self.dropout = Dropout(config.encoder_transformer_dropout) + + self.transformer = ModuleList([TransformerEncoderLayer(self.base.resulting_embeddings_size, config.encoder_transformer_heads, config.encoder_transformer_hidden_size, config.encoder_transformer_dropout) for _ in range(config.encoder_transformer_layers)]) + self.layer_norm = 
LayerNorm(self.base.resulting_embeddings_size, eps=1e-6) + + config.encoder_output_size = self.base.resulting_embeddings_size + + # input: + # - embeddings List[FloatTensor] - features x batch x seq x hidden + # - pad_mask LongTensor - batch x seq + # output: + # - output FloatTensor - batch x seq x hidden + def forward(self, embeddings, pad_mask): + embeddings = self.base(embeddings, pad_mask) # batch x seq x hidden + if self.embeddings_scale is not None: + embeddings = embeddings * self.embeddings_scale + if self.positional_encoding is not None: + embeddings = embeddings + self.positional_encoding(embeddings.size(1)) + embeddings = self.dropout(embeddings) + pad_mask = pad_mask.eq(pad_token_index).unsqueeze(1) # batch x 1 x seq + for layer in self.transformer: + embeddings = layer(embeddings, pad_mask) + embeddings = self.layer_norm(embeddings) + return embeddings diff --git a/python/getalp/wsd/modules/positional_encoding.py b/python/getalp/wsd/modules/positional_encoding.py new file mode 100644 index 0000000..d2de228 --- /dev/null +++ b/python/getalp/wsd/modules/positional_encoding.py @@ -0,0 +1,28 @@ +from torch.nn import Module +from getalp.wsd.torch_fix import * +import math + + +class PositionalEncoding(Module): + + def __init__(self, input_embeddings_size, max_len=5000): + super().__init__() + pe = torch_zeros(max_len, input_embeddings_size) # max_len x input_embeddings_size + position = torch_arange(start=0, end=max_len, step=1).unsqueeze(1) # max_len x 1 + div_term = torch_exp((torch_arange(start=0, end=input_embeddings_size, step=2, dtype=torch_float32) * -(math.log(10000.0) / input_embeddings_size))) + pe[:, 0::2] = torch_sin(position.float() * div_term) + pe[:, 1::2] = torch_cos(position.float() * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + self.pe = pe + self.input_embeddings_size = input_embeddings_size + + # inputs: + # - int (seq) + # output: + # - FloatTensor (1 x seq x hidden) + def forward(self, seq: int, full: bool = True): + if full: + return self.pe[:, :seq, :] + else: + return self.pe[:, seq, :] diff --git a/python/getalp/wsd/optim/__init__.py b/python/getalp/wsd/optim/__init__.py new file mode 100644 index 0000000..73e175d --- /dev/null +++ b/python/getalp/wsd/optim/__init__.py @@ -0,0 +1,2 @@ +from .scheduler_fixed import SchedulerFixed +from .scheduler_noam import SchedulerNoam diff --git a/python/getalp/wsd/optim/scheduler_fixed.py b/python/getalp/wsd/optim/scheduler_fixed.py new file mode 100644 index 0000000..726392e --- /dev/null +++ b/python/getalp/wsd/optim/scheduler_fixed.py @@ -0,0 +1,10 @@ + + +class SchedulerFixed(object): + + def __init__(self, fixed_lr: float): + super().__init__() + self.fixed_lr = fixed_lr + + def get_learning_rate(self, step: int): + return self.fixed_lr diff --git a/python/getalp/wsd/optim/scheduler_noam.py b/python/getalp/wsd/optim/scheduler_noam.py new file mode 100644 index 0000000..e7355de --- /dev/null +++ b/python/getalp/wsd/optim/scheduler_noam.py @@ -0,0 +1,12 @@ + + +class SchedulerNoam(object): + + def __init__(self, warmup: int, model_size: int): + super().__init__() + self.warmup = warmup + self.model_size = model_size + + def get_learning_rate(self, step: int): + step = max(1, step) + return self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)) diff --git a/python/getalp/wsd/predict.py b/python/getalp/wsd/predict.py index f395070..f2bf066 100644 --- a/python/getalp/wsd/predict.py +++ b/python/getalp/wsd/predict.py @@ -6,11 +6,21 @@ def main(): parser = 
argparse.ArgumentParser() parser.add_argument('--data_path', required=True, type=str) parser.add_argument('--weights', nargs="+", type=str) + parser.add_argument('--clear_text', action="store_true", help=" ") + parser.add_argument('--batch_size', nargs="?", type=int, default=1, help=" ") + parser.add_argument('--disambiguate', action="store_true", help=" ") + parser.add_argument('--beam_size', nargs="?", type=int, default=1, help=" ") + parser.add_argument('--output_all_features', action="store_true", help=" ") args = parser.parse_args() predicter = Predicter() predicter.training_root_path = args.data_path predicter.ensemble_weights_path = args.weights + predicter.clear_text = args.clear_text + predicter.batch_size = args.batch_size + predicter.disambiguate = args.disambiguate + predicter.beam_size = args.beam_size + predicter.output_all_features = args.output_all_features predicter.predict() diff --git a/python/getalp/wsd/predicter.py b/python/getalp/wsd/predicter.py index 847f295..1f8bc68 100644 --- a/python/getalp/wsd/predicter.py +++ b/python/getalp/wsd/predicter.py @@ -1,6 +1,6 @@ from getalp.wsd.common import * -from getalp.wsd.model import Model, ModelConfig -import numpy as np +from getalp.wsd.model import Model, ModelConfig, DataConfig +from torch.nn.functional import log_softmax import sys @@ -8,85 +8,185 @@ class Predicter(object): def __init__(self): self.training_root_path: str = str() - self.ensemble_weights_path: str = str() + self.ensemble_weights_path: List[str] = [] + self.clear_text: bool = bool() + self.batch_size: int = int() + self.disambiguate: bool = bool() + self.translate: bool = False + self.beam_size: int = int() + self.output_all_features: bool = bool() + self.data_config: DataConfig = None def predict(self): config_file_path = self.training_root_path + "/config.json" - config = ModelConfig() + self.data_config = DataConfig() + self.data_config.load_from_file(config_file_path) + config = ModelConfig(self.data_config) config.load_from_file(config_file_path) + if self.clear_text: + config.data_config.input_clear_text = [True for _ in range(config.data_config.input_features)] + if self.data_config.output_features <= 0: + self.disambiguate = False + if self.data_config.output_translations <= 0: + self.translate = False - model = Model() - model.config = config + assert(self.disambiguate or self.translate) - ensemble = self.create_ensemble(len(self.ensemble_weights_path), model) - self.load_ensemble_weights(ensemble, self.ensemble_weights_path) + ensemble = self.create_ensemble(config, self.ensemble_weights_path) - output = None i = 0 + batch_x = None + batch_z = None for line in sys.stdin: if i == 0: - sample_x = read_sample_x_or_y_from_string(line) - output = self.predict_ensemble_on_sample(ensemble, sample_x) - i = 1 + sample_x = read_sample_x_from_string(line, feature_count=config.data_config.input_features, clear_text=config.data_config.input_clear_text) + self.preprocess_sample_x(ensemble, sample_x) + if batch_x is None: + batch_x = [[] for _ in range(len(sample_x))] + for j in range(len(sample_x)): + batch_x[j].append(sample_x[j]) + if self.disambiguate and not self.output_all_features: + i = 1 + else: + if len(batch_x[0]) >= self.batch_size: + self.predict_and_output(ensemble, batch_x, batch_z, self.data_config.input_clear_text) + batch_x = None elif i == 1: - sample_z = read_sample_z_from_string(line) - sample_y = self.generate_wsd_on_sample(output, sample_z) - sys.stdout.write(sample_y + "\n") - sys.stdout.flush() + sample_z = 
read_sample_z_from_string(line, feature_count=config.data_config.output_features) + if batch_z is None: + batch_z = [[] for _ in range(len(sample_z))] + for j in range(len(sample_z)): + batch_z[j].append(sample_z[j]) i = 0 + if len(batch_z[0]) >= self.batch_size: + self.predict_and_output(ensemble, batch_x, batch_z, self.data_config.input_clear_text) + batch_x = None + batch_z = None + if batch_x is not None: + self.predict_and_output(ensemble, batch_x, batch_z, self.data_config.input_clear_text) - @staticmethod - def create_ensemble(ensemble_size, model): - ensemble = [] - for _ in range(ensemble_size): - copy = Model() - copy.config = model.config - copy.create_model() - ensemble.append(copy) + def create_ensemble(self, config: ModelConfig, ensemble_weights_paths: List[str]): + ensemble = [Model(config) for _ in range(len(ensemble_weights_paths))] + for i in range(len(ensemble)): + ensemble[i].load_model_weights(ensemble_weights_paths[i]) + ensemble[i].set_beam_size(self.beam_size) return ensemble - @staticmethod - def load_ensemble_weights(ensemble, ensemble_weights_paths): - for i in range(0, len(ensemble)): - ensemble[i].load_model_weights(ensemble_weights_paths[i]) - + def preprocess_sample_x(ensemble: List[Model], sample_x): + ensemble[0].preprocess_samples([[sample_x]]) + + def predict_and_output(self, ensemble: List[Model], batch_x, batch_z, clear_text): + pad_batch_x(batch_x, clear_text) + output_wsd, output_translation = None, None + # TODO: refact this horror + if self.disambiguate and not self.translate and self.output_all_features: + output_all_features = Predicter.predict_ensemble_all_features_on_batch(ensemble, batch_x) + batch_all_features = Predicter.generate_all_features_on_batch(output_all_features, batch_x) + for sample_all_features in batch_all_features: + sys.stdout.write(sample_all_features + "\n") + sys.stdout.flush() + return + if self.disambiguate and not self.translate: + output_wsd = Predicter.predict_ensemble_wsd_on_batch(ensemble, batch_x) + elif self.translate and not self.disambiguate: + output_translation = Predicter.predict_ensemble_translation_on_batch(ensemble, batch_x) + else: + output_wsd, output_translation = Predicter.predict_ensemble_wsd_and_translation_on_batch(ensemble, batch_x) + if output_wsd is not None and output_translation is None: + batch_wsd = Predicter.generate_wsd_on_batch(output_wsd, batch_z) + for sample_wsd in batch_wsd: + sys.stdout.write(sample_wsd + "\n") + elif output_translation is not None and output_wsd is None: + batch_translation = Predicter.generate_translation_on_batch(output_translation, ensemble[0].config.data_config.output_translation_vocabularies[0][0]) + for sample_translation in batch_translation: + sys.stdout.write(sample_translation + "\n") + elif output_wsd is not None and output_translation is not None: + batch_wsd = Predicter.generate_wsd_on_batch(output_wsd, batch_z) + batch_translation = Predicter.generate_translation_on_batch(output_translation, ensemble[0].config.data_config.output_translation_vocabularies[0][0]) + assert len(batch_wsd) == len(batch_translation) + for i in range(len(batch_wsd)): + sys.stdout.write(batch_wsd[i] + "\n") + sys.stdout.write(batch_translation[i] + "\n") + sys.stdout.flush() @staticmethod - def predict_ensemble_on_sample(ensemble, sample_x): + def predict_ensemble_wsd_on_batch(ensemble: List[Model], batch_x): if len(ensemble) == 1: - return ensemble[0].predict_model_on_sample(sample_x) + return ensemble[0].predict_wsd_on_batch(batch_x) ensemble_sample_y = None for model in 
ensemble: - model_sample_y = model.predict_model_on_sample(sample_x) - model_sample_y = np.log(model_sample_y) + model_sample_y = model.predict_wsd_on_batch(batch_x) + model_sample_y = log_softmax(model_sample_y, dim=2) if ensemble_sample_y is None: ensemble_sample_y = model_sample_y else: - ensemble_sample_y = np.sum([ensemble_sample_y, model_sample_y], axis=0) - ensemble_sample_y = np.divide(ensemble_sample_y, len(ensemble)) - ensemble_sample_y = np.exp(ensemble_sample_y) + ensemble_sample_y = model_sample_y + ensemble_sample_y return ensemble_sample_y + @staticmethod + def predict_ensemble_all_features_on_batch(ensemble: List[Model], batch_x): + if len(ensemble) == 1: + return ensemble[0].predict_all_features_on_batch(batch_x) + else: + # TODO: manage ensemble + return None + + @staticmethod + def predict_ensemble_translation_on_batch(ensemble: List[Model], batch_x): + if len(ensemble) == 1: + return ensemble[0].predict_translation_on_batch(batch_x) + else: + # TODO: manage ensemble + return None + + @staticmethod + def predict_ensemble_wsd_and_translation_on_batch(ensemble: List[Model], batch_x): + if len(ensemble) == 1: + return ensemble[0].predict_wsd_and_translation_on_batch(batch_x) + else: + # TODO: manage ensemble + return None + + @staticmethod + def generate_wsd_on_batch(output, batch_z): + batch_wsd = [] + for i in range(len(batch_z[0])): + batch_wsd.append(Predicter.generate_wsd_on_sample(output[i], batch_z[0][i])) + return batch_wsd + + @staticmethod + def generate_all_features_on_batch(output, batch_x): + batch_wsd = [] + for i in range(len(batch_x[0])): + batch_wsd.append(Predicter.generate_all_features_on_sample(output, batch_x, i)) + return batch_wsd + + @staticmethod + def generate_translation_on_batch(output, vocabulary): + return unpad_turn_to_text_and_remove_bpe_of_batch_t(output, vocabulary) @staticmethod def generate_wsd_on_sample(output, sample_z): - sample_y = "" - for j in range(len(output)): - if j < len(sample_z[0]): - restricted_possibilities = sample_z[0][j] - max_proba = 0 + sample_wsd: List[str] = [] + for i in range(len(sample_z)): + restricted_possibilities = sample_z[i] + if 0 in restricted_possibilities: + sample_wsd.append("0") + elif -1 in restricted_possibilities: + sample_wsd.append(str(torch_argmax(output[i]).item())) + else: + max_proba = None max_possibility = None for possibility in restricted_possibilities: - if possibility != 0: - proba = output[j][possibility] - if proba > max_proba: - max_proba = proba - max_possibility = possibility - if max_possibility is not None: - sample_y += str(max_possibility) + " " - else: - sample_y += "0 " - return sample_y + proba = output[i][possibility] + if max_proba is None or proba > max_proba: + max_proba = proba + max_possibility = possibility + sample_wsd.append(str(max_possibility)) + return " ".join(sample_wsd) + @staticmethod + def generate_all_features_on_sample(output, batch_x, i): + return " ".join(["/".join([str(torch_argmax(output[k][i][j]).item()) for k in range(len(output))]) for j in range(len(batch_x[0][i]))]) diff --git a/python/getalp/wsd/torch_fix.py b/python/getalp/wsd/torch_fix.py new file mode 100644 index 0000000..e4e2e73 --- /dev/null +++ b/python/getalp/wsd/torch_fix.py @@ -0,0 +1,78 @@ +import torch + + +def torch_tensor(*args, **kwargs) -> torch.Tensor: + return torch.tensor(*args, **kwargs) + + +def torch_from_numpy(*args, **kwargs) -> torch.Tensor: + return torch.from_numpy(*args, **kwargs) + + +def torch_empty(*args, **kwargs) -> torch.Tensor: + return torch.empty(*args, 
**kwargs) + + +def torch_zeros(*args, **kwargs) -> torch.Tensor: + return torch.zeros(*args, **kwargs) + + +def torch_ones(*args, **kwargs) -> torch.Tensor: + return torch.ones(*args, **kwargs) + + +def torch_ones_like(*args, **kwargs) -> torch.Tensor: + return torch.ones_like(*args, **kwargs) + + +def torch_full(*args, **kwargs) -> torch.Tensor: + return torch.full(*args, **kwargs) + + +def torch_cat(*args, **kwargs) -> torch.Tensor: + return torch.cat(*args, **kwargs) + + +def torch_stack(*args, **kwargs) -> torch.Tensor: + return torch.stack(*args, **kwargs) + + +def torch_transpose(*args, **kwargs) -> torch.Tensor: + return torch.transpose(*args, **kwargs) + + +def torch_squeeze(*args, **kwargs) -> torch.Tensor: + return torch.squeeze(*args, **kwargs) + + +def torch_unsqueeze(*args, **kwargs) -> torch.Tensor: + return torch.unsqueeze(*args, **kwargs) + + +def torch_max(*args, **kwargs) -> torch.Tensor: + return torch.max(*args, **kwargs) + + +def torch_argmax(*args, **kwargs) -> torch.Tensor: + return torch.argmax(*args, **kwargs) + + +def torch_arange(*args, **kwargs) -> torch.Tensor: + return torch.arange(*args, **kwargs) + + +def torch_exp(*args, **kwargs) -> torch.Tensor: + return torch.exp(*args, **kwargs) + + +def torch_sin(*args, **kwargs) -> torch.Tensor: + return torch.sin(*args, **kwargs) + + +def torch_cos(*args, **kwargs) -> torch.Tensor: + return torch.cos(*args, **kwargs) + + +torch_float32: torch.dtype = torch.float32 +torch_long: torch.dtype = torch.long +torch_uint8: torch.dtype = torch.uint8 diff --git a/python/getalp/wsd/torch_utils.py b/python/getalp/wsd/torch_utils.py new file mode 100644 index 0000000..712a6fa --- /dev/null +++ b/python/getalp/wsd/torch_utils.py @@ -0,0 +1,11 @@ +import torch + + +cpu_device = torch.device("cpu") + +if torch.cuda.is_available(): + gpu_device = torch.device("cuda:0") + default_device = gpu_device +else: + default_device = cpu_device + diff --git a/python/getalp/wsd/train.py b/python/getalp/wsd/train.py index 23c0400..2da6ad0 100644 --- a/python/getalp/wsd/train.py +++ b/python/getalp/wsd/train.py @@ -1,5 +1,7 @@ from getalp.wsd.trainer import Trainer +from getalp.common.common import str2bool import argparse +import pprint def main(): @@ -7,30 +9,84 @@ def main(): parser.add_argument('--data_path', required=True, type=str, help=" ") parser.add_argument('--model_path', required=True, type=str, help=" ") parser.add_argument('--batch_size', nargs="?", type=int, default=100, help=" ") + parser.add_argument('--token_per_batch', nargs="?", type=int, default=8000, help=" ") parser.add_argument('--ensemble_count', nargs="?", type=int, default=8, help=" ") - parser.add_argument('--epoch_count', nargs="?", type=int, default=100, help=" ") + parser.add_argument('--epoch_count', nargs="?", type=int, default=40, help=" ") parser.add_argument('--eval_frequency', nargs="?", type=int, default=4000, help=" ") parser.add_argument('--update_frequency', nargs="?", type=int, default=1, help=" ") - parser.add_argument('--lr', nargs="?", type=float, default=0.0001, help=" ") - parser.add_argument('--warmup_sample_size', nargs="?", type=int, default=80, help=" ") + parser.add_argument('--warmup_batch_count', nargs="?", type=int, default=10, help=" ") + parser.add_argument('--input_embeddings_size', nargs="+", type=int, default=None, help=" ") + parser.add_argument('--input_elmo_model', nargs="+", type=str, default=None, help=" ") + parser.add_argument('--input_bert_model', nargs="+", type=str, default=None, help=" ") + 
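# model architecture options: when provided, the following flags override the corresponding values saved in the data directory's config.json (see trainer.py) +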
parser.add_argument('--input_word_dropout_rate', nargs="?", type=float, default=None, help=" ") + parser.add_argument('--input_apply_linear', nargs="?", type=str2bool, default=None, help=" ") + parser.add_argument('--input_linear_size', nargs="?", type=int, default=None, help=" ") + parser.add_argument('--input_dropout_rate', nargs="?", type=float, default=None, help=" ") + parser.add_argument('--encoder_type', nargs="?", type=str, default=None, help=" ") + parser.add_argument('--encoder_lstm_hidden_size', nargs="?", type=int, default=None, help=" ") + parser.add_argument('--encoder_lstm_layers', nargs="?", type=int, default=None, help=" ") + parser.add_argument('--encoder_lstm_dropout', nargs="?", type=float, default=None, help=" ") + parser.add_argument('--encoder_transformer_hidden_size', nargs="?", type=int, default=None, help=" ") + parser.add_argument('--encoder_transformer_layers', nargs="?", type=int, default=None, help=" ") + parser.add_argument('--encoder_transformer_heads', nargs="?", type=int, default=None, help=" ") + parser.add_argument('--encoder_transformer_dropout', nargs="?", type=float, default=None, help=" ") + parser.add_argument('--encoder_transformer_positional_encoding', nargs="?", type=str2bool, default=None, help=" ") + parser.add_argument('--encoder_transformer_scale_embeddings', nargs="?", type=str2bool, default=None, help=" ") + parser.add_argument('--optimizer', nargs="?", type=str, default="adam", help=" ", choices=["adam"]) + parser.add_argument('--adam_beta1', nargs="?", type=float, default=0.9, help=" ") + parser.add_argument('--adam_beta2', nargs="?", type=float, default=0.999, help=" ") + parser.add_argument('--adam_eps', nargs="?", type=float, default=1e-8, help=" ") + parser.add_argument('--lr_scheduler', nargs="?", type=str, default="fixed", help=" ", choices=("fixed", "noam")) + parser.add_argument('--lr_scheduler_fixed_lr', nargs="?", type=float, default=0.0001, help=" ") + parser.add_argument('--lr_scheduler_noam_warmup', nargs="?", type=int, default=6000, help=" ") + parser.add_argument('--lr_scheduler_noam_model_size', nargs="?", type=int, default=512, help=" ") parser.add_argument('--reset', action="store_true", help=" ") + parser.add_argument('--save_best_loss', action="store_true", help=" ") + parser.add_argument('--save_every_epoch', action="store_true", help=" ") args = parser.parse_args() - print(args) + print("Command line arguments:") + pprint.pprint(vars(args)) trainer = Trainer() + trainer.data_path = args.data_path trainer.model_path = args.model_path trainer.batch_size = args.batch_size - trainer.test_every_batch = args.eval_frequency + trainer.token_per_batch = args.token_per_batch + trainer.eval_frequency = args.eval_frequency trainer.update_every_batch = args.update_frequency trainer.stop_after_epoch = args.epoch_count trainer.ensemble_size = args.ensemble_count - trainer.save_best_loss = False - trainer.save_end_of_epoch = False + trainer.save_best_loss = args.save_best_loss + trainer.save_end_of_epoch = args.save_every_epoch trainer.shuffle_train_on_init = True - trainer.warmup_sample_size = args.warmup_sample_size + trainer.warmup_batch_count = args.warmup_batch_count + trainer.input_embeddings_size = args.input_embeddings_size + trainer.input_elmo_model = args.input_elmo_model + trainer.input_bert_model = args.input_bert_model + trainer.input_word_dropout_rate = args.input_word_dropout_rate + trainer.input_apply_linear = args.input_apply_linear + trainer.input_linear_size = args.input_linear_size + trainer.input_dropout_rate = 
args.input_dropout_rate + trainer.encoder_type = args.encoder_type + trainer.encoder_lstm_layers = args.encoder_lstm_layers + trainer.encoder_lstm_hidden_size = args.encoder_lstm_hidden_size + trainer.encoder_lstm_dropout = args.encoder_lstm_dropout + trainer.encoder_transformer_hidden_size = args.encoder_transformer_hidden_size + trainer.encoder_transformer_layers = args.encoder_transformer_layers + trainer.encoder_transformer_heads = args.encoder_transformer_heads + trainer.encoder_transformer_dropout = args.encoder_transformer_dropout + trainer.encoder_transformer_positional_encoding = args.encoder_transformer_positional_encoding + trainer.encoder_transformer_scale_embeddings = args.encoder_transformer_scale_embeddings + trainer.optimizer = args.optimizer + trainer.adam_beta1 = args.adam_beta1 + trainer.adam_beta2 = args.adam_beta2 + trainer.adam_eps = args.adam_eps + trainer.lr_scheduler = args.lr_scheduler + trainer.lr_scheduler_fixed_lr = args.lr_scheduler_fixed_lr + trainer.lr_scheduler_noam_warmup = args.lr_scheduler_noam_warmup + trainer.lr_scheduler_noam_model_size = args.lr_scheduler_noam_model_size trainer.reset = args.reset - trainer.learning_rate = args.lr trainer.train() diff --git a/python/getalp/wsd/trainer.py b/python/getalp/wsd/trainer.py index 757894c..4d24a43 100644 --- a/python/getalp/wsd/trainer.py +++ b/python/getalp/wsd/trainer.py @@ -1,7 +1,16 @@ from getalp.wsd.common import * -from getalp.wsd.model import Model, ModelConfig -from getalp.common.common import create_directory_if_not_exists +from getalp.wsd.model import Model, ModelConfig, DataConfig +from getalp.common.common import create_directory_if_not_exists, set_if_not_none import os +import sys +import shutil +import pprint +import sacrebleu +import random +try: + import tensorboardX +except ImportError: + tensorboardX = None class Trainer(object): @@ -10,80 +19,180 @@ def __init__(self): self.data_path: str = str() self.model_path: str = str() self.batch_size = int() - self.test_every_batch = int() + self.token_per_batch = int() + self.eval_frequency = int() + self.update_every_batch: int = int() self.stop_after_epoch = int() self.ensemble_size = int() self.save_best_loss = bool() self.save_end_of_epoch = bool() self.shuffle_train_on_init = bool() - self.warmup_sample_size: int = int() + self.warmup_batch_count: int = int() + self.input_embeddings_size: List[int] = None + self.input_elmo_model: List[str] = None + self.input_bert_model: List[str] = None + self.input_flair_model: List[str] = None + self.input_word_dropout_rate: float = None + self.input_resize: List[int] = None + self.input_apply_linear: bool = None + self.input_linear_size: int = None + self.input_dropout_rate: float = None + self.encoder_type: str = None + self.encoder_lstm_hidden_size: int = None + self.encoder_lstm_layers: int = None + self.encoder_lstm_dropout: float = None + self.encoder_transformer_hidden_size: int = None + self.encoder_transformer_layers: int = None + self.encoder_transformer_heads: int = None + self.encoder_transformer_dropout: float = None + self.encoder_transformer_positional_encoding: bool = None + self.encoder_transformer_scale_embeddings: bool = None + self.decoder_translation_transformer_hidden_size: int = None + self.decoder_translation_transformer_dropout: float = None + self.decoder_translation_scale_embeddings: bool = None + self.decoder_translation_share_embeddings: bool = None + self.decoder_translation_share_encoder_embeddings: bool = None + self.decoder_translation_tokenizer_bert: str = None + 
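# optimizer and learning-rate scheduler hyper-parameters, filled in from the command-line arguments of train.py +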
self.optimizer: str = str() + self.adam_beta1: float = float() + self.adam_beta2: float = float() + self.adam_eps: float = float() + self.lr_scheduler: str = str() + self.lr_scheduler_fixed_lr: float = float() + self.lr_scheduler_noam_warmup: int = int() + self.lr_scheduler_noam_model_size: int = int() self.reset: bool = bool() - self.learning_rate: float = float() - self.update_every_batch: int = int() - def train(self): model_weights_last_path = self.model_path + "/model_weights_last" model_weights_loss_path = self.model_path + "/model_weights_loss" model_weights_wsd_path = self.model_path + "/model_weights_wsd" + model_weights_bleu_path = self.model_path + "/model_weights_bleu" model_weights_end_of_epoch_path = self.model_path + "/model_weights_end_of_epoch_" training_info_path = self.model_path + "/training_info" - training_losses_path = self.model_path + "/training_losses" + tensorboard_path = self.model_path + "/tensorboard" train_file_path = self.data_path + "/train" dev_file_path = self.data_path + "/dev" config_file_path = self.data_path + "/config.json" print("Loading config and embeddings") - config = ModelConfig() + data_config: DataConfig = DataConfig() + data_config.load_from_file(config_file_path) + config: ModelConfig = ModelConfig(data_config) config.load_from_file(config_file_path) - print("Creating model") - model = Model() - model.config = config - self.recreate_model(model) - - print("Warming up on fake batch") - batch_x, batch_y = create_fake_batch(batch_size=self.batch_size, sample_size=self.warmup_sample_size, input_features=model.config.input_features, input_vocabulary_sizes=model.config.input_vocabulary_sizes, output_features=model.config.output_features, output_vocabulary_sizes=model.config.output_vocabulary_sizes) - model.begin_train_on_batch() - model.train_on_batch(batch_x, batch_y, None) - model.end_train_on_batch() - - self.recreate_model(model) - - print("Loading training and development data") - train_samples = read_all_samples_from_file(train_file_path) - dev_samples = read_all_samples_from_file(dev_file_path) + # change config from CLI parameters + config.input_embeddings_sizes = set_if_not_none(self.input_embeddings_size, config.input_embeddings_sizes) + if self.input_elmo_model is not None: + config.set_input_elmo_path(self.input_elmo_model) + if self.input_bert_model is not None: + config.set_input_bert_model(self.input_bert_model) + if self.input_flair_model is not None: + config.set_input_flair_model(self.input_flair_model) + if self.input_word_dropout_rate is not None: + config.input_word_dropout_rate = self.input_word_dropout_rate + eprint("Warning: input_word_dropout_rate is not implemented") + config.input_apply_linear = set_if_not_none(self.input_apply_linear, config.input_apply_linear) + config.input_linear_size = set_if_not_none(self.input_linear_size, config.input_linear_size) + config.input_dropout_rate = set_if_not_none(self.input_dropout_rate, config.input_dropout_rate) + config.encoder_type = set_if_not_none(self.encoder_type, config.encoder_type) + config.encoder_lstm_hidden_size = set_if_not_none(self.encoder_lstm_hidden_size, config.encoder_lstm_hidden_size) + config.encoder_lstm_layers = set_if_not_none(self.encoder_lstm_layers, config.encoder_lstm_layers) + config.encoder_lstm_dropout = set_if_not_none(self.encoder_lstm_dropout, config.encoder_lstm_dropout) + config.encoder_transformer_hidden_size = set_if_not_none(self.encoder_transformer_hidden_size, config.encoder_transformer_hidden_size) + config.encoder_transformer_layers = 
set_if_not_none(self.encoder_transformer_layers, config.encoder_transformer_layers) + config.encoder_transformer_heads = set_if_not_none(self.encoder_transformer_heads, config.encoder_transformer_heads) + config.encoder_transformer_dropout = set_if_not_none(self.encoder_transformer_dropout, config.encoder_transformer_dropout) + config.encoder_transformer_positional_encoding = set_if_not_none(self.encoder_transformer_positional_encoding, config.encoder_transformer_positional_encoding) + config.encoder_transformer_scale_embeddings = set_if_not_none(self.encoder_transformer_scale_embeddings, config.encoder_transformer_scale_embeddings) + config.decoder_translation_transformer_hidden_size = set_if_not_none(self.decoder_translation_transformer_hidden_size, config.decoder_translation_transformer_hidden_size) + config.decoder_translation_transformer_dropout = set_if_not_none(self.decoder_translation_transformer_dropout, config.decoder_translation_transformer_dropout) + config.decoder_translation_scale_embeddings = set_if_not_none(self.decoder_translation_scale_embeddings, config.decoder_translation_scale_embeddings) + config.decoder_translation_share_embeddings = set_if_not_none(self.decoder_translation_share_embeddings, config.decoder_translation_share_embeddings) + config.decoder_translation_share_encoder_embeddings = set_if_not_none(self.decoder_translation_share_encoder_embeddings, config.decoder_translation_share_encoder_embeddings) + config.decoder_translation_tokenizer_bert = set_if_not_none(self.decoder_translation_tokenizer_bert, config.decoder_translation_tokenizer_bert) + + print("GPU is available: " + str(torch.cuda.is_available())) + + model: Model = Model(config) + model.set_adam_parameters(adam_beta1=self.adam_beta1, adam_beta2=self.adam_beta2, adam_eps=self.adam_eps) + model.set_lr_scheduler(lr_scheduler=self.lr_scheduler, fixed_lr=self.lr_scheduler_fixed_lr, warmup=self.lr_scheduler_noam_warmup, model_size=self.lr_scheduler_noam_model_size) current_ensemble = 0 current_epoch = 0 current_batch = 0 + current_batch_total = 0 current_sample_index = 0 - best_dev_wsd = None best_dev_loss = None + best_dev_wsd = None + best_dev_bleu = None + random_seed = self.generate_random_seed() if not self.reset and os.path.isfile(training_info_path) and os.path.isfile(model_weights_last_path): print("Resuming from previous training") - current_ensemble, current_epoch, current_batch, current_sample_index, best_dev_wsd, best_dev_loss = load_training_info(training_info_path) + current_ensemble, current_epoch, current_batch, current_batch_total, current_sample_index, best_dev_loss, best_dev_wsd, best_dev_bleu, random_seed = load_training_info(training_info_path) model.load_model_weights(model_weights_last_path) - elif self.shuffle_train_on_init: + else: + print("Creating model") + model.create_model() + create_directory_if_not_exists(self.model_path) + + print("Random seed is " + str(random_seed)) + + print("Config is: ") + pprint.pprint(config.get_serializable_data()) + + print("Number of parameters (total): " + model.get_number_of_parameters(filter_requires_grad=False)) + print("Number of parameters (learned): " + model.get_number_of_parameters(filter_requires_grad=True)) + + print("Warming up on " + str(self.warmup_batch_count) + " batches") + train_samples = read_samples_from_file(train_file_path, data_config.input_clear_text, data_config.output_features, data_config.output_translations, data_config.output_translation_features, data_config.output_translation_clear_text, 
self.batch_size*self.warmup_batch_count) + model.preprocess_samples(train_samples) + for i in range(self.warmup_batch_count): + batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(train_samples, self.batch_size, -1, 0, data_config.input_features, data_config.output_features, data_config.output_translations, data_config.output_translation_features, data_config.input_clear_text, data_config.output_translation_clear_text) + model.begin_train_on_batch() + model.train_on_batch(batch_x, batch_y, batch_tt) + model.end_train_on_batch() + + print("Loading training and development data") + train_samples = read_samples_from_file(train_file_path, input_clear_text=data_config.input_clear_text, output_features=data_config.output_features, output_translations=data_config.output_translations, output_translation_features=data_config.output_translation_features, output_translation_clear_text=data_config.output_translation_clear_text) + dev_samples = read_samples_from_file(dev_file_path, input_clear_text=data_config.input_clear_text, output_features=data_config.output_features, output_translations=data_config.output_translations, output_translation_features=data_config.output_translation_features, output_translation_clear_text=data_config.output_translation_clear_text) + + print("Preprocessing training and development data") + model.preprocess_samples(train_samples) + model.preprocess_samples(dev_samples) + + if self.shuffle_train_on_init: print("Shuffling training data") + random.seed(random_seed) random.shuffle(train_samples) - create_directory_if_not_exists(self.model_path) + self.print_state(current_ensemble, current_epoch, current_batch, current_batch_total, len(train_samples), current_sample_index, [None for _ in range(data_config.output_features + data_config.output_translations * data_config.output_translation_features)], [None for _ in range(data_config.output_features + data_config.output_translations * data_config.output_translation_features)], [None for _ in range(data_config.output_features)], None) - self.print_state(current_ensemble, current_epoch, current_batch, [None for _ in range(model.config.output_features)], [None for _ in range(model.config.output_features)], None) + if self.reset: + shutil.rmtree(tensorboard_path, ignore_errors=True) for current_ensemble in range(current_ensemble, self.ensemble_size): + if tensorboardX is not None: + tb_writer = tensorboardX.SummaryWriter(tensorboard_path + '/ensemble' + str(current_ensemble)) + else: + tb_writer = None sample_accumulate_between_eval = 0 train_losses = None while self.stop_after_epoch == -1 or current_epoch < self.stop_after_epoch: + model.update_learning_rate(step=current_batch_total) + + print("training sample " + str(current_sample_index) + "/" + str(len(train_samples)), end="\r") + sys.stdout.flush() + reached_eof = False model.begin_train_on_batch() for _ in range(self.update_every_batch): - batch_x, batch_y, batch_z, actual_batch_size, reached_eof = read_batch_from_samples(train_samples, self.batch_size, current_sample_index) - if actual_batch_size == 0: break - batch_losses = model.train_on_batch(batch_x, batch_y, batch_z) + batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(train_samples, self.batch_size, self.token_per_batch, current_sample_index, data_config.input_features, data_config.output_features, data_config.output_translations, data_config.output_translation_features, data_config.input_clear_text, 
data_config.output_translation_clear_text) + if actual_batch_size == 0: + break + batch_losses = model.train_on_batch(batch_x, batch_y, batch_tt) if train_losses is None: train_losses = [0 for _ in batch_losses] for i in range(len(batch_losses)): @@ -91,106 +200,178 @@ def train(self): current_sample_index += actual_batch_size sample_accumulate_between_eval += actual_batch_size current_batch += 1 - if reached_eof: break + current_batch_total += 1 + if reached_eof: + break model.end_train_on_batch() if reached_eof: - print("Reached eof at batch " + str(current_batch)) if self.save_end_of_epoch: model.save_model_weights(model_weights_end_of_epoch_path + str(current_epoch) + "_" + str(current_ensemble)) current_batch = 0 current_sample_index = 0 current_epoch += 1 + random_seed = self.generate_random_seed() + random.seed(random_seed) random.shuffle(train_samples) - if current_batch % self.test_every_batch == 0: - dev_wsd, dev_losses = self.test_on_dev(self.batch_size, dev_samples, model) + if current_batch % self.eval_frequency == 0: + dev_losses, dev_wsd, dev_bleu = self.test_on_dev(dev_samples, model) for i in range(len(train_losses)): train_losses[i] /= float(sample_accumulate_between_eval) - self.print_state(current_ensemble, current_epoch, current_batch, train_losses, dev_losses, dev_wsd) - save_training_losses(training_losses_path, train_losses[0], dev_losses[0], dev_wsd) + self.print_state(current_ensemble, current_epoch, current_batch, current_batch_total, len(train_samples), current_sample_index, train_losses, dev_losses, dev_wsd, dev_bleu) + self.write_tensorboard(tb_writer, current_epoch, train_samples, current_sample_index, train_losses, dev_losses, dev_wsd, data_config.output_feature_names, dev_bleu, model.optimizer.scheduler.get_learning_rate(current_batch_total)) sample_accumulate_between_eval = 0 train_losses = None if best_dev_loss is None or dev_losses[0] < best_dev_loss: if self.save_best_loss: model.save_model_weights(model_weights_loss_path + str(current_ensemble)) + print("New best dev loss: " + str(dev_losses[0])) best_dev_loss = dev_losses[0] - if best_dev_wsd is None or dev_wsd > best_dev_wsd: + if len(dev_wsd) > 0 and (best_dev_wsd is None or dev_wsd[0] > best_dev_wsd): model.save_model_weights(model_weights_wsd_path + str(current_ensemble)) - best_dev_wsd = dev_wsd + best_dev_wsd = dev_wsd[0] print("New best dev WSD: " + str(best_dev_wsd)) + if (best_dev_bleu is None or dev_bleu > best_dev_bleu) and dev_bleu is not None: + model.save_model_weights(model_weights_bleu_path + str(current_ensemble)) + best_dev_bleu = dev_bleu + print("New best dev BLEU: " + str(best_dev_bleu)) + model.save_model_weights(model_weights_last_path) - save_training_info(training_info_path, current_ensemble, current_epoch, current_batch, current_sample_index, best_dev_wsd, best_dev_loss) + save_training_info(training_info_path, current_ensemble, current_epoch, current_batch, current_batch_total, current_sample_index, best_dev_loss, best_dev_wsd, best_dev_bleu, random_seed) - self.recreate_model(model) + model.create_model() current_epoch = 0 - best_dev_wsd = None + current_batch_total = 0 best_dev_loss = None - - - def recreate_model(self, model): - model.create_model() - model.set_learning_rate(self.learning_rate) - + best_dev_wsd = None + best_dev_bleu = None @staticmethod - def print_state(current_ensemble, current_epoch, current_batch, train_losses, dev_losses, dev_wsd): - print("Ensemble " + str(current_ensemble) + " - Epoch " + str(current_epoch) + " - Batch " + str(current_batch) + " - 
Train losses = " + str(train_losses) + " - Dev losses = " + str(dev_losses) + " - Dev wsd = " + str(dev_wsd)) - - - def test_on_dev(self, batch_size, dev_samples, model): - loss = self.get_loss_metrics(batch_size, dev_samples, model) - wsd = self.get_wsd_metrics(batch_size, dev_samples, model) - return wsd, loss + def generate_random_seed(): + return int.from_bytes(os.urandom(8), byteorder='big', signed=False) + @staticmethod + def print_state(current_ensemble, current_epoch, current_batch, current_batch_total, samples_count, current_sample, train_losses, dev_losses, dev_wsd, dev_bleu): + print("Ensemble " + str(current_ensemble) + " - Epoch " + str(current_epoch) + " - Batch " + str(current_batch) + + " - Sample " + str(current_sample) + " - Total Batch " + str(current_batch_total) + " - Total Sample " + str(current_epoch * samples_count + current_sample) + + " - Train losses = " + str(train_losses) + + " - Dev losses = " + str(dev_losses) + " - Dev wsd = " + str(dev_wsd) + " - Dev bleu = " + str(dev_bleu)) @staticmethod - def get_loss_metrics(batch_size, dev_samples, model): - losses = None + def write_tensorboard(tb_writer, current_epoch, train_samples, current_sample_index, train_losses, dev_losses, dev_wsd, output_annotation_names, dev_bleu, learning_rate): + if tb_writer is None: + return + tb_index = current_epoch * len(train_samples) + current_sample_index + if len(train_losses) > 0: + tb_writer.add_scalar('train_loss', train_losses[0], tb_index) + if len(dev_losses) > 0: + tb_writer.add_scalar('dev_loss', dev_losses[0], tb_index) + if len(dev_wsd) > 0: + tb_writer.add_scalar('dev_wsd', dev_wsd[0], tb_index) + if dev_bleu is not None: + tb_writer.add_scalar('dev_bleu', dev_bleu, tb_index) + + tb_writer.add_scalar('learning_rate', learning_rate, tb_index) + + for i in range(len(train_losses)): + tb_writer.add_scalar('train_losses/train_loss' + str(i), train_losses[i], tb_index) + for i in range(len(dev_losses)): + tb_writer.add_scalar('dev_losses/dev_loss' + str(i), dev_losses[i], tb_index) + for i in range(1, len(dev_wsd)): + tb_writer.add_scalar('dev_wsds/dev_wsd' + str(i) + "_" + output_annotation_names[i-1].replace("%", "_"), dev_wsd[i], tb_index) + + def test_on_dev(self, dev_samples, model: Model): + loss = self.get_loss_metrics(dev_samples, model) + wsd = self.get_wsd_metrics(dev_samples, model) + bleu = None + if model.config.data_config.output_translations > 0: + bleu = self.get_bleu_metrics(dev_samples, model) + return loss, wsd, bleu + + def get_loss_metrics(self, dev_samples, model: Model): + losses = [0 for _ in range(model.config.data_config.output_features + model.config.data_config.output_translations * model.config.data_config.output_translation_features)] reached_eof = False current_index = 0 while not reached_eof: - batch_x, batch_y, batch_z, actual_batch_size, reached_eof = read_batch_from_samples(dev_samples, batch_size, current_index) - if actual_batch_size == 0: break - batch_losses = model.test_model_on_batch(batch_x, batch_y, batch_z) - if losses is None: - losses = [0 for _ in batch_losses] + batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(dev_samples, self.batch_size, self.token_per_batch, current_index, model.config.data_config.input_features, model.config.data_config.output_features, model.config.data_config.output_translations, model.config.data_config.output_translation_features, model.config.data_config.input_clear_text, model.config.data_config.output_translation_clear_text) + if actual_batch_size == 0: + 
break + batch_losses = model.test_model_on_batch(batch_x, batch_y, batch_tt) for i in range(len(batch_losses)): losses[i] += (batch_losses[i] * actual_batch_size) current_index += actual_batch_size - for i in range(len(losses)): - losses[i] /= float(current_index) + if current_index != 0: + for i in range(len(losses)): + losses[i] /= float(current_index) return losses - - @staticmethod - def get_wsd_metrics(batch_size, dev_samples, model): - good = 0 - total = 0 + def get_wsd_metrics(self, dev_samples, model: Model): + output_features = model.config.data_config.output_features + if output_features == 0: + return [] + goods = [0 for _ in range(output_features)] + totals = [0 for _ in range(output_features)] + reached_eof = False + current_index = 0 + while not reached_eof: + batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(dev_samples, self.batch_size, self.token_per_batch, current_index, model.config.data_config.input_features, model.config.data_config.output_features, model.config.data_config.output_translations, model.config.data_config.output_translation_features, model.config.data_config.input_clear_text, model.config.data_config.output_translation_clear_text) + if actual_batch_size == 0: + break + output = model.predict_all_features_on_batch(batch_x) + for k in range(len(output)): # k: feat + for i in range(len(output[k])): # i: batch + for j in range(len(output[k][i])): # j: seq + if j < len(batch_z[k][i]): + restricted_possibilities = batch_z[k][i][j] + max_possibility = None + if 0 in restricted_possibilities: + max_possibility = None + elif -1 in restricted_possibilities: + max_possibility = torch_argmax(output[k][i][j]).item() + else: + max_proba = None + for possibility in restricted_possibilities: + proba = output[k][i][j][possibility].item() + if max_proba is None or proba > max_proba: + max_proba = proba + max_possibility = possibility + if max_possibility is not None: + totals[k] += 1 + if max_possibility == batch_y[k][i][j].item(): + goods[k] += 1 + current_index += actual_batch_size + all_wsd = [((float(goods[i]) / float(totals[i])) * float(100)) if totals[i] != 0 else float(0) for i in range(output_features)] + if output_features > 1: + if sum(totals) != 0: + summary_wsd = [(float(sum(goods)) / float(sum(totals))) * float(100)] + else: + summary_wsd = [float(0)] + return summary_wsd + all_wsd + else: + return all_wsd + + def get_bleu_metrics(self, dev_samples, model: Model): reached_eof = False current_index = 0 + all_hypothesis_sentences = [] + all_reference_sentences = [] while not reached_eof: - batch_x, batch_y, batch_z, actual_batch_size, reached_eof = read_batch_from_samples(dev_samples, batch_size, current_index) - if actual_batch_size == 0: break - output = model.predict_model_on_batch(batch_x) - for i in range(len(output)): - for j in range(len(output[i])): - if j < len(batch_z[0][i]): - restricted_possibilities = batch_z[0][i][j] - max_proba = None - max_possibility = None - for possibility in restricted_possibilities: - if possibility != 0: - proba = output[i][j][possibility] - if max_proba is None or proba > max_proba: - max_proba = proba - max_possibility = possibility - if max_possibility is not None: - total += 1 - if max_possibility == batch_y[0][i][j]: - good += 1 + batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(dev_samples, self.batch_size, self.token_per_batch, current_index, model.config.data_config.input_features, model.config.data_config.output_features, 
model.config.data_config.output_translations, model.config.data_config.output_translation_features, model.config.data_config.input_clear_text, model.config.data_config.output_translation_clear_text) + if actual_batch_size == 0: + break + reference = unpad_turn_to_text_and_remove_bpe_of_batch_t(batch_tt[0][0], model.config.data_config.output_translation_vocabularies[0][0]) + for sentence in reference: + all_reference_sentences.append(sentence) + output = model.predict_translation_on_batch(batch_x) + output = unpad_turn_to_text_and_remove_bpe_of_batch_t(output, model.config.data_config.output_translation_vocabularies[0][0]) + for sentence in output: + all_hypothesis_sentences.append(sentence) current_index += actual_batch_size - return (float(good) / float(total)) * float(100) + if reached_eof is True: + break + bleu = sacrebleu.raw_corpus_bleu(sys_stream=all_hypothesis_sentences, ref_streams=[all_reference_sentences]) + return bleu.score diff --git a/python/onmt/__init__.py b/python/onmt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/onmt/encoders/__init__.py b/python/onmt/encoders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/onmt/encoders/encoder.py b/python/onmt/encoders/encoder.py new file mode 100644 index 0000000..7cc542f --- /dev/null +++ b/python/onmt/encoders/encoder.py @@ -0,0 +1,58 @@ +"""Base class for encoders and generic multi encoders.""" + +import torch.nn as nn + +from onmt.utils.misc import aeq + + +class EncoderBase(nn.Module): + """ + Base encoder class. Specifies the interface used by different encoder types + and required by :class:`onmt.Models.NMTModel`. + + .. mermaid:: + + graph BT + A[Input] + subgraph RNN + C[Pos 1] + D[Pos 2] + E[Pos N] + end + F[Memory_Bank] + G[Final] + A-->C + A-->D + A-->E + C-->F + D-->F + E-->F + E-->G + """ + + @classmethod + def from_opt(cls, opt, embeddings=None): + raise NotImplementedError + + def _check_args(self, src, lengths=None, hidden=None): + _, n_batch, _ = src.size() + if lengths is not None: + n_batch_, = lengths.size() + aeq(n_batch, n_batch_) + + def forward(self, src, lengths=None): + """ + Args: + src (LongTensor): + padded sequences of sparse indices ``(src_len, batch, nfeat)`` + lengths (LongTensor): length of each sequence ``(batch,)`` + + + Returns: + (FloatTensor, FloatTensor): + + * final encoder state, used to initialize decoder + * memory bank for attention, ``(src_len, batch, hidden)`` + """ + + raise NotImplementedError diff --git a/python/onmt/encoders/transformer.py b/python/onmt/encoders/transformer.py new file mode 100644 index 0000000..f9726d1 --- /dev/null +++ b/python/onmt/encoders/transformer.py @@ -0,0 +1,135 @@ +""" +Implementation of "Attention is All You Need" +""" + +import torch.nn as nn + +from onmt.encoders.encoder import EncoderBase +from onmt.modules.multi_headed_attn import MultiHeadedAttention +from onmt.modules.position_ffn import PositionwiseFeedForward + + +class TransformerEncoderLayer(nn.Module): + """ + A single layer of the transformer encoder. + + Args: + d_model (int): the dimension of keys/values/queries in + MultiHeadedAttention, also the input size of + the first-layer of the PositionwiseFeedForward. + heads (int): the number of head for MultiHeadedAttention. + d_ff (int): the second-layer of the PositionwiseFeedForward. + dropout (float): dropout probability(0-1.0). 
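+        max_relative_positions (int): maximum distance used for relative position representations (0 disables them)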
+ """ + + def __init__(self, d_model, heads, d_ff, dropout, + max_relative_positions=0): + super(TransformerEncoderLayer, self).__init__() + + self.self_attn = MultiHeadedAttention( + heads, d_model, dropout=dropout, + max_relative_positions=max_relative_positions) + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.dropout = nn.Dropout(dropout) + + def forward(self, inputs, mask): + """ + Args: + inputs (FloatTensor): ``(batch_size, src_len, model_dim)`` + mask (LongTensor): ``(batch_size, src_len, src_len)`` + + Returns: + (FloatTensor): + + * outputs ``(batch_size, src_len, model_dim)`` + """ + input_norm = self.layer_norm(inputs) + context, _ = self.self_attn(input_norm, input_norm, input_norm, + mask=mask, type="self") + out = self.dropout(context) + inputs + return self.feed_forward(out) + + def update_dropout(self, dropout): + self.self_attn.update_dropout(dropout) + self.feed_forward.update_dropout(dropout) + self.dropout.p = dropout + + +class TransformerEncoder(EncoderBase): + """The Transformer encoder from "Attention is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + .. mermaid:: + + graph BT + A[input] + B[multi-head self-attn] + C[feed forward] + O[output] + A --> B + B --> C + C --> O + + Args: + num_layers (int): number of encoder layers + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + dropout (float): dropout parameters + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + + Returns: + (torch.FloatTensor, torch.FloatTensor): + + * embeddings ``(src_len, batch_size, model_dim)`` + * memory_bank ``(src_len, batch_size, model_dim)`` + """ + + def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, + max_relative_positions): + super(TransformerEncoder, self).__init__() + + self.embeddings = embeddings + self.transformer = nn.ModuleList( + [TransformerEncoderLayer( + d_model, heads, d_ff, dropout, + max_relative_positions=max_relative_positions) + for i in range(num_layers)]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.enc_layers, + opt.enc_rnn_size, + opt.heads, + opt.transformer_ff, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + embeddings, + opt.max_relative_positions) + + def forward(self, src, lengths=None): + """See :func:`EncoderBase.forward()`""" + self._check_args(src, lengths) + + emb = self.embeddings(src) + + out = emb.transpose(0, 1).contiguous() + words = src[:, :, 0].transpose(0, 1) + w_batch, w_len = words.size() + padding_idx = self.embeddings.word_padding_idx + mask = words.data.eq(padding_idx).unsqueeze(1) # [B, 1, T] + # Run the forward pass of every layer of the tranformer. 
+        for layer in self.transformer:
+            out = layer(out, mask)
+        out = self.layer_norm(out)
+
+        return emb, out.transpose(0, 1).contiguous(), lengths
+
+    def update_dropout(self, dropout):
+        self.embeddings.update_dropout(dropout)
+        for layer in self.transformer:
+            layer.update_dropout(dropout)
diff --git a/python/onmt/modules/__init__.py b/python/onmt/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python/onmt/modules/multi_headed_attn.py b/python/onmt/modules/multi_headed_attn.py
new file mode 100644
index 0000000..e9158ae
--- /dev/null
+++ b/python/onmt/modules/multi_headed_attn.py
@@ -0,0 +1,232 @@
+""" Multi-Head Attention module """
+import math
+import torch
+import torch.nn as nn
+
+from onmt.utils.misc import generate_relative_positions_matrix,\
+                            relative_matmul
+# from onmt.utils.misc import aeq
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention module from "Attention is All You Need"
+    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
+
+    Similar to standard `dot` attention but uses
+    multiple attention distributions simultaneously
+    to select relevant items.
+
+    .. mermaid::
+
+       graph BT
+          A[key]
+          B[value]
+          C[query]
+          O[output]
+          subgraph Attn
+            D[Attn 1]
+            E[Attn 2]
+            F[Attn N]
+          end
+          A --> D
+          C --> D
+          A --> E
+          C --> E
+          A --> F
+          C --> F
+          D --> O
+          E --> O
+          F --> O
+          B --> O
+
+    Also includes several additional tricks.
+
+    Args:
+        head_count (int): number of parallel heads
+        model_dim (int): the dimension of keys/values/queries,
+            must be divisible by head_count
+        dropout (float): dropout parameter
+    """
+
+    def __init__(self, head_count, model_dim, dropout=0.1,
+                 max_relative_positions=0):
+        assert model_dim % head_count == 0
+        self.dim_per_head = model_dim // head_count
+        self.model_dim = model_dim
+
+        super(MultiHeadedAttention, self).__init__()
+        self.head_count = head_count
+
+        self.linear_keys = nn.Linear(model_dim,
+                                     head_count * self.dim_per_head)
+        self.linear_values = nn.Linear(model_dim,
+                                       head_count * self.dim_per_head)
+        self.linear_query = nn.Linear(model_dim,
+                                      head_count * self.dim_per_head)
+        self.softmax = nn.Softmax(dim=-1)
+        self.dropout = nn.Dropout(dropout)
+        self.final_linear = nn.Linear(model_dim, model_dim)
+
+        self.max_relative_positions = max_relative_positions
+
+        if max_relative_positions > 0:
+            vocab_size = max_relative_positions * 2 + 1
+            self.relative_positions_embeddings = nn.Embedding(
+                vocab_size, self.dim_per_head)
+
+    def forward(self, key, value, query, mask=None,
+                layer_cache=None, type=None):
+        """
+        Compute the context vector and the attention vectors.
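+
+        Concretely (steps 1-3 below): project the inputs per head, take the
+        scaled dot product ``softmax(Q K^T / sqrt(dim_per_head)) V``, then
+        concatenate the heads and apply a final linear layer.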
+
+        Args:
+            key (FloatTensor): set of `key_len`
+                key vectors ``(batch, key_len, dim)``
+            value (FloatTensor): set of `key_len`
+                value vectors ``(batch, key_len, dim)``
+            query (FloatTensor): set of `query_len`
+                query vectors ``(batch, query_len, dim)``
+            mask: binary mask indicating which keys have
+                non-zero attention ``(batch, query_len, key_len)``
+        Returns:
+            (FloatTensor, FloatTensor):
+
+            * output context vectors ``(batch, query_len, dim)``
+            * one of the attention vectors ``(batch, query_len, key_len)``
+        """
+
+        # CHECKS
+        # batch, k_len, d = key.size()
+        # batch_, k_len_, d_ = value.size()
+        # aeq(batch, batch_)
+        # aeq(k_len, k_len_)
+        # aeq(d, d_)
+        # batch_, q_len, d_ = query.size()
+        # aeq(batch, batch_)
+        # aeq(d, d_)
+        # aeq(self.model_dim % 8, 0)
+        # if mask is not None:
+        #    batch_, q_len_, k_len_ = mask.size()
+        #    aeq(batch_, batch)
+        #    aeq(k_len_, k_len)
+        #    aeq(q_len_, q_len)
+        # END CHECKS
+
+        batch_size = key.size(0)
+        dim_per_head = self.dim_per_head
+        head_count = self.head_count
+        key_len = key.size(1)
+        query_len = query.size(1)
+        device = key.device
+
+        def shape(x):
+            """Projection."""
+            return x.view(batch_size, -1, head_count, dim_per_head) \
+                .transpose(1, 2)
+
+        def unshape(x):
+            """Compute context."""
+            return x.transpose(1, 2).contiguous() \
+                    .view(batch_size, -1, head_count * dim_per_head)
+
+        # 1) Project key, value, and query.
+        if layer_cache is not None:
+            if type == "self":
+                query, key, value = self.linear_query(query),\
+                                    self.linear_keys(query),\
+                                    self.linear_values(query)
+                key = shape(key)
+                value = shape(value)
+                if layer_cache["self_keys"] is not None:
+                    key = torch.cat(
+                        (layer_cache["self_keys"].to(device), key),
+                        dim=2)
+                if layer_cache["self_values"] is not None:
+                    value = torch.cat(
+                        (layer_cache["self_values"].to(device), value),
+                        dim=2)
+                layer_cache["self_keys"] = key
+                layer_cache["self_values"] = value
+            elif type == "context":
+                query = self.linear_query(query)
+                if layer_cache["memory_keys"] is None:
+                    key, value = self.linear_keys(key),\
+                                 self.linear_values(value)
+                    key = shape(key)
+                    value = shape(value)
+                else:
+                    key, value = layer_cache["memory_keys"],\
+                                 layer_cache["memory_values"]
+                layer_cache["memory_keys"] = key
+                layer_cache["memory_values"] = value
+        else:
+            key = self.linear_keys(key)
+            value = self.linear_values(value)
+            query = self.linear_query(query)
+            key = shape(key)
+            value = shape(value)
+
+        if self.max_relative_positions > 0 and type == "self":
+            key_len = key.size(2)
+            # 1 or key_len x key_len
+            relative_positions_matrix = generate_relative_positions_matrix(
+                key_len, self.max_relative_positions,
+                cache=True if layer_cache is not None else False)
+            # 1 or key_len x key_len x dim_per_head
+            relations_keys = self.relative_positions_embeddings(
+                relative_positions_matrix.to(device))
+            # 1 or key_len x key_len x dim_per_head
+            relations_values = self.relative_positions_embeddings(
+                relative_positions_matrix.to(device))
+
+        query = shape(query)
+
+        key_len = key.size(2)
+        query_len = query.size(2)
+
+        # 2) Calculate and scale scores.
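+        # Dividing by sqrt(dim_per_head) keeps the dot products in a range
+        # where the softmax still has useful gradients (Vaswani et al., 2017).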
+        query = query / math.sqrt(dim_per_head)
+        # batch x num_heads x query_len x key_len
+        query_key = torch.matmul(query, key.transpose(2, 3))
+
+        if self.max_relative_positions > 0 and type == "self":
+            scores = query_key + relative_matmul(query, relations_keys, True)
+        else:
+            scores = query_key
+        scores = scores.float()
+
+        if mask is not None:
+            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
+            scores = scores.masked_fill(mask, -1e18)
+
+        # 3) Apply attention dropout and compute context vectors.
+        attn = self.softmax(scores).to(query.dtype)
+        drop_attn = self.dropout(attn)
+
+        context_original = torch.matmul(drop_attn, value)
+
+        if self.max_relative_positions > 0 and type == "self":
+            context = unshape(context_original
+                              + relative_matmul(drop_attn,
+                                                relations_values,
+                                                False))
+        else:
+            context = unshape(context_original)
+
+        output = self.final_linear(context)
+        # CHECK
+        # batch_, q_len_, d_ = output.size()
+        # aeq(q_len, q_len_)
+        # aeq(batch, batch_)
+        # aeq(d, d_)
+
+        # Return one attn
+        top_attn = attn \
+            .view(batch_size, head_count,
+                  query_len, key_len)[:, 0, :, :] \
+            .contiguous()
+
+        return output, top_attn
+
+    def update_dropout(self, dropout):
+        self.dropout.p = dropout
diff --git a/python/onmt/modules/position_ffn.py b/python/onmt/modules/position_ffn.py
new file mode 100644
index 0000000..fb8df80
--- /dev/null
+++ b/python/onmt/modules/position_ffn.py
@@ -0,0 +1,41 @@
+"""Position feed-forward network from "Attention is All You Need"."""
+
+import torch.nn as nn
+
+
+class PositionwiseFeedForward(nn.Module):
+    """ A two-layer Feed-Forward-Network with residual layer norm.
+
+    Args:
+        d_model (int): the size of input for the first layer of the FFN.
+        d_ff (int): the hidden layer size of the second layer
+            of the FFN.
+        dropout (float): dropout probability in :math:`[0, 1)`.
+    """
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = nn.Linear(d_model, d_ff)
+        self.w_2 = nn.Linear(d_ff, d_model)
+        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.dropout_1 = nn.Dropout(dropout)
+        self.relu = nn.ReLU()
+        self.dropout_2 = nn.Dropout(dropout)
+
+    def forward(self, x):
+        """Layer definition.
+
+        Args:
+            x: ``(batch_size, input_len, model_dim)``
+
+        Returns:
+            (FloatTensor): Output ``(batch_size, input_len, model_dim)``.
+        """
+
+        inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
+        output = self.dropout_2(self.w_2(inter))
+        return output + x
+
+    def update_dropout(self, dropout):
+        self.dropout_1.p = dropout
+        self.dropout_2.p = dropout
diff --git a/python/onmt/utils/__init__.py b/python/onmt/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python/onmt/utils/misc.py b/python/onmt/utils/misc.py
new file mode 100644
index 0000000..845668f
--- /dev/null
+++ b/python/onmt/utils/misc.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import random
+import inspect
+from itertools import islice
+
+
+def split_corpus(path, shard_size):
+    with open(path, "rb") as f:
+        if shard_size <= 0:
+            yield f.readlines()
+        else:
+            while True:
+                shard = list(islice(f, shard_size))
+                if not shard:
+                    break
+                yield shard
+
+
+def aeq(*args):
+    """
+    Assert all arguments have the same value
+    """
+    arguments = (arg for arg in args)
+    first = next(arguments)
+    assert all(arg == first for arg in arguments), \
+        "Not all arguments have the same value: " + str(args)
+
+
+def sequence_mask(lengths, max_len=None):
+    """
+    Creates a boolean mask from sequence lengths.
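+
+    For instance, ``sequence_mask(torch.tensor([1, 3]), max_len=3)`` yields
+    ``[[1, 0, 0], [1, 1, 1]]`` (an illustrative sketch; the mask dtype
+    follows the comparison semantics of the installed PyTorch version).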
+    """
+    batch_size = lengths.numel()
+    max_len = max_len or lengths.max()
+    return (torch.arange(0, max_len)
+            .type_as(lengths)
+            .repeat(batch_size, 1)
+            .lt(lengths.unsqueeze(1)))
+
+
+def tile(x, count, dim=0):
+    """
+    Tiles x on dimension dim count times.
+    """
+    perm = list(range(len(x.size())))
+    if dim != 0:
+        perm[0], perm[dim] = perm[dim], perm[0]
+        x = x.permute(perm).contiguous()
+    out_size = list(x.size())
+    out_size[0] *= count
+    batch = x.size(0)
+    x = x.view(batch, -1) \
+         .transpose(0, 1) \
+         .repeat(count, 1) \
+         .transpose(0, 1) \
+         .contiguous() \
+         .view(*out_size)
+    if dim != 0:
+        x = x.permute(perm).contiguous()
+    return x
+
+
+def use_gpu(opt):
+    """
+    Returns True if a GPU is requested in the options.
+    """
+    return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
+        (hasattr(opt, 'gpu') and opt.gpu > -1)
+
+
+def set_random_seed(seed, is_cuda):
+    """Sets the random seed."""
+    if seed > 0:
+        torch.manual_seed(seed)
+        # this one is needed for torchtext random call (shuffled iterator)
+        # in multi gpu it ensures datasets are read in the same order
+        random.seed(seed)
+        # some cudnn methods can be random even after fixing the seed
+        # unless you tell it to be deterministic
+        torch.backends.cudnn.deterministic = True
+
+    if is_cuda and seed > 0:
+        # These ensure same initialization in multi gpu mode
+        torch.cuda.manual_seed(seed)
+
+
+def generate_relative_positions_matrix(length, max_relative_positions,
+                                       cache=False):
+    """Generate the clipped relative positions matrix
+       for a given length and maximum relative positions"""
+    if cache:
+        distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0)
+    else:
+        range_vec = torch.arange(length)
+        range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1)
+        distance_mat = range_mat - range_mat.transpose(0, 1)
+    distance_mat_clipped = torch.clamp(distance_mat,
+                                       min=-max_relative_positions,
+                                       max=max_relative_positions)
+    # Shift values to be >= 0
+    final_mat = distance_mat_clipped + max_relative_positions
+    return final_mat
+
+
+def relative_matmul(x, z, transpose):
+    """Helper function for relative positions attention."""
+    batch_size = x.shape[0]
+    heads = x.shape[1]
+    length = x.shape[2]
+    x_t = x.permute(2, 0, 1, 3)
+    x_t_r = x_t.reshape(length, heads * batch_size, -1)
+    if transpose:
+        z_t = z.transpose(1, 2)
+        x_tz_matmul = torch.matmul(x_t_r, z_t)
+    else:
+        x_tz_matmul = torch.matmul(x_t_r, z)
+    x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1)
+    x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3)
+    return x_tz_matmul_r_t
+
+
+def fn_args(fun):
+    """Returns the list of function argument names."""
+    return inspect.getfullargspec(fun).args
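
As a quick sanity check of `generate_relative_positions_matrix` above, the following standalone sketch (not part of the patch) reproduces its non-cached branch: distances `j - i` are clamped to the window `[-k, k]` and shifted by `k` so they can index an embedding table of size `2*k + 1`.

    import torch

    # Mirrors the non-cached branch of generate_relative_positions_matrix
    # in python/onmt/utils/misc.py (illustrative values).
    length, k = 4, 2  # sequence length, max_relative_positions
    positions = torch.arange(length)
    distance = positions.unsqueeze(0) - positions.unsqueeze(1)  # [i][j] = j - i
    relative = torch.clamp(distance, min=-k, max=k) + k  # shift values to be >= 0
    print(relative)
    # tensor([[2, 3, 4, 4],
    #         [1, 2, 3, 4],
    #         [0, 1, 2, 3],
    #         [0, 0, 1, 2]])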