diff --git a/Pipfile b/Pipfile
index 9860653..5d968eb 100644
--- a/Pipfile
+++ b/Pipfile
@@ -18,6 +18,8 @@ scikit-learn = "*"
 torch = "*"
 pytorch-ignite = "*"
 torchtext = "*"
+pytorch-pretrained-bert = "*"
+pandas = "*"
 
 [requires]
 python_version = "3.6"
diff --git a/Pipfile.lock b/Pipfile.lock
index 6ada896..308cfa6 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "9c3ebbd53468ba536a91e2972c0f7affb2004c1787485c72b4bda1979d84d074"
+            "sha256": "5a112ee07d4119f2a059b1d324542eec71ab5941f427813b3e0fd055c1c2f640"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -52,6 +52,20 @@
             ],
             "version": "==3.1.0"
         },
+        "boto3": {
+            "hashes": [
+                "sha256:1b4a86e1167ba7cbb9dbf2a0a0b86447b35a2b901ae5aace75b8196631680957",
+                "sha256:f5b12367c530dac45782251b672f1e911da5c74285f89850b0f4f5694b8c388c"
+            ],
+            "version": "==1.9.115"
+        },
+        "botocore": {
+            "hashes": [
+                "sha256:7c8ec120bc5bcc4076aebd7dac3a679777ff3a3ce3263c64d7342ea7982b578c",
+                "sha256:f4607f8800f87fd8eacd450699666f92d7fbc48fbb757903ad56825ce08e072a"
+            ],
+            "version": "==1.12.115"
+        },
         "certifi": {
             "hashes": [
                 "sha256:59b7658e26ca9c7339e00f8f4636cdfe59d34fa37b9b04f6f9e9926b3cece1a5",
@@ -127,6 +141,14 @@
             ],
             "version": "==0.2.9"
         },
+        "docutils": {
+            "hashes": [
+                "sha256:02aec4bd92ab067f6ff27a38a38a41173bf01bed8f89157768c1573f53e474a6",
+                "sha256:51e64ef2ebfb29cae1faa133b3710143496eca21c530f3f71424d77687764274",
+                "sha256:7a4bd47eaf6596e1295ecb11361139febe29b084a87bf005bf899f9a42edc3c6"
+            ],
+            "version": "==0.14"
+        },
         "entrypoints": {
             "hashes": [
                 "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19",
@@ -198,6 +220,13 @@
             ],
             "version": "==2.10"
         },
+        "jmespath": {
+            "hashes": [
+                "sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6",
+                "sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c"
+            ],
+            "version": "==0.9.4"
+        },
         "jsonschema": {
             "hashes": [
                 "sha256:0c0a81564f181de3212efa2d17de1910f8732fa1b71c42266d983cd74304e20d",
@@ -453,6 +482,7 @@
                 "sha256:cc8fc0c7a8d5951dc738f1c1447f71c43734244453616f32b8aa0ef6013a5dfb",
                 "sha256:d7b460bc316064540ce0c41c1438c416a40746fd8a4fb2999668bf18f3c4acf1"
             ],
+            "index": "pypi",
             "version": "==0.24.2"
         },
         "pandocfilters": {
@@ -587,6 +617,7 @@
                 "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb",
                 "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e"
             ],
+            "markers": "python_version >= '2.7'",
             "version": "==2.8.0"
         },
         "pytorch-ignite": {
@@ -597,6 +628,15 @@
             "index": "pypi",
             "version": "==0.1.2"
         },
+        "pytorch-pretrained-bert": {
+            "hashes": [
+                "sha256:138c9702cc8da0c949a3b266a0c6e436aee4ae1c722b5d3eb1e47fb4b2b0f197",
+                "sha256:9ec5998f501381d86d6e0b4c4d92c1c2888f3f093e3a13177b3b94494b1bf7d7",
+                "sha256:f30ae5d19a95b64bd7068170640608cc457488948aa7643855aa261c2c8ab8b7"
+            ],
+            "index": "pypi",
+            "version": "==0.6.1"
+        },
         "pytz": {
             "hashes": [
                 "sha256:32b0891edff07e28efe91284ed9c31e123d84bea3fd98e1f72be2508f43ef8d9",
@@ -668,6 +708,13 @@
             ],
             "version": "==2.21.0"
         },
+        "s3transfer": {
+            "hashes": [
+                "sha256:7b9ad3213bff7d357f888e0fab5101b56fa1a0548ee77d121c3a3dbfbef4cb2e",
+                "sha256:f23d5cb7d862b104401d9021fc82e5fa0e0cf57b7660a1331425aab0c691d021"
+            ],
+            "version": "==0.2.0"
+        },
         "scikit-learn": {
             "hashes": [
                 "sha256:018f470a7e685767d84ce6fac87af59e064e87ec3cea71eaf12646f9538e293d",
@@ -875,6 +922,7 @@
                 "sha256:61bf29cada3fc2fbefad4fdf059ea4bd1b4a86d2b6d15e1c7c0b582b9752fe39",
                 "sha256:de9529817c93f27c8ccbfead6985011db27bd0ddfcdb2d86f3f663385c6a9c22"
             ],
+            "markers": "python_version >= '3.4'",
             "version": "==1.24.1"
         },
         "wcwidth": {
diff --git a/bin/test.py b/bin/test.py
index 4a3271d..0b56a1b 100644
--- a/bin/test.py
+++ b/bin/test.py
@@ -10,9 +10,9 @@
 
 from torch.utils.data import DataLoader
 
+from src.vocab import get_vocab
 from src.matching_network import MatchingNetwork
 from src.evaluation import (predict, save_predictions,
                             generate_episode_data)
-from src.data import read_vocab, read_data_set
 from src.datasets import EpisodesSampler, EpisodesDataset
 from src.utils import extract_model_parameters, get_model_name
@@ -50,11 +50,11 @@ parser.add_argument("test_set", help="Path to the test CSV file")
 
 
-def _load_model(model_path):
+def _load_model(model_path, vocab):
     model_file_name = os.path.basename(args.model)
     distance, embeddings, N, k = extract_model_parameters(model_file_name)
     model_name = get_model_name(distance, embeddings, N, k)
 
-    model = MatchingNetwork(model_name, distance_metric=distance)
+    model = MatchingNetwork(model_name, vocab, distance_metric=distance)
     model_state_dict = torch.load(model_path)
     model.load_state_dict(model_state_dict)
 
@@ -63,11 +63,11 @@ def _load_model(model_path):
 
 def main(args):
     print("Loading model...")
-    model, _, N, k = _load_model(args.model)
+    vocab = get_vocab(args.embeddings, args.vocab)
+    model, embeddings, N, k = _load_model(args.model, vocab)
 
     print("Loading dataset...")
-    vocab = read_vocab(args.vocab)
-    X_test, y_test = read_data_set(args.test_set, vocab)
+    X_test, y_test = vocab.to_tensors(args.test_set)
     test_set = EpisodesDataset(X_test, y_test, k=k)
     sampler = EpisodesSampler(test_set, N=N, episodes_multiplier=30)
     test_loader = DataLoader(test_set, sampler=sampler, batch_size=BATCH_SIZE)
diff --git a/bin/train.py b/bin/train.py
index 17942ee..9eabfa8 100644
--- a/bin/train.py
+++ b/bin/train.py
@@ -5,7 +5,7 @@
 
 from torch.utils.data import DataLoader
 
-from src.data import read_vocab, read_data_set
+from src.vocab import get_vocab
 from src.datasets import EpisodesSampler, EpisodesDataset
 from src.matching_network import MatchingNetwork
 from src.training import train
@@ -44,6 +44,14 @@
     type=str,
     default='cosine',
     help="Distance metric to be used")
+parser.add_argument(
+    "-e",
+    "--embeddings",
+    action="store",
+    dest="embeddings",
+    type=str,
+    default='vanilla',
+    help="Type of embedding")
 parser.add_argument(
     "-p",
     "--processing-steps",
@@ -65,8 +73,8 @@ def _get_loader(data_set, N, episodes_multiplier=1):
 
 def main(args):
     print("Loading dataset...")
-    vocab = read_vocab(args.vocab)
-    X_train, y_train = read_data_set(args.training_set, vocab)
+    vocab = get_vocab(args.embeddings, args.vocab)
+    X_train, y_train = vocab.to_tensors(args.training_set)
 
     # Split training further into train and valid
     X_train, X_valid, y_train, y_valid = train_test_split_tensors(
@@ -77,11 +85,12 @@ def main(args):
     print("Initialising model...")
     model_name = get_model_name(
         distance=args.distance_metric,
-        embeddings='vanilla',
+        embeddings=args.embeddings,
         N=args.N,
         k=args.k)
     model = MatchingNetwork(
         model_name,
+        vocab,
         fce=True,
         processing_steps=args.processing_steps,
         distance_metric=args.distance_metric)
diff --git a/bin/vocab.py b/bin/vocab.py
index ce4a309..e625130 100644
--- a/bin/vocab.py
+++ b/bin/vocab.py
@@ -4,7 +4,7 @@
 
 from argparse import ArgumentParser
 
-from src.data import generate_vocab, store_vocab
+from src.vocab import VanillaVocab
 
 parser = ArgumentParser()
 parser.add_argument("input", help="Path to the input CSV data set")
@@ -13,10 +13,10 @@
 
 
 def main(args):
     print("Generating vocab...")
-    vocab = generate_vocab(args.input)
+    vocab = VanillaVocab.generate_vocab(args.input)
 
     print("Storing vocab...")
-    store_vocab(vocab, args.output)
+    vocab.save(args.output)
 
     print(f"Stored vocab of size {len(vocab)} at {args.output}")
diff --git a/src/data.py b/src/data.py
index f1e88f1..b1d1be5 100644
--- a/src/data.py
+++ b/src/data.py
@@ -9,6 +9,9 @@
 from torchtext.data import Field, TabularDataset
 from torchtext.vocab import Vocab
 
+print(
+    "[WARNING] Don't use src.data anymore. Use the Vocab interfaces instead.")
+
 VOCAB_SIZE = 27443
 PADDING_TOKEN_INDEX = 1
diff --git a/src/evaluation.py b/src/evaluation.py
index c61f551..0329e49 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -6,8 +6,6 @@
 import torch
 import numpy as np
 
-from .data import reverse_tensor
-
 RESULTS_PATH = os.path.join(
     os.path.dirname(os.path.dirname(__file__)), "results")
 
@@ -174,11 +172,11 @@ def _episode_to_text(support_set, targets, labels, target_labels, vocab):
     # First, we need to flatten these...
     N, k, _ = support_set.shape
     flat_support_set = support_set.view(N * k, -1)
-    flat_support_set = reverse_tensor(flat_support_set, vocab)
+    flat_support_set = vocab.to_text(flat_support_set)
     support_set = flat_support_set.reshape(N, k)
 
-    targets = reverse_tensor(targets, vocab)
-    labels = reverse_tensor(labels, vocab)
-    target_labels = reverse_tensor(target_labels, vocab)
+    targets = vocab.to_text(targets)
+    labels = vocab.to_text(labels)
+    target_labels = vocab.to_text(target_labels)
 
     return support_set, targets, labels, target_labels
diff --git a/src/matching_network.py b/src/matching_network.py
index 367f07d..95dc881 100644
--- a/src/matching_network.py
+++ b/src/matching_network.py
@@ -1,8 +1,11 @@
 import torch
+
 from torch import nn
 from torch.nn import functional as F
+
+from pytorch_pretrained_bert import BertModel
+
 from .similarity import get_similarity_func
-from .data import VOCAB_SIZE, PADDING_TOKEN_INDEX
 
 
 class EncodingLayer(nn.Module):
@@ -11,23 +14,33 @@ class EncodingLayer(nn.Module):
     embedding.
     """
 
-    def __init__(self, vocab_size, encoding_size):
+    def __init__(self, encoding_size, vocab):
         """
         Initialises the encoding layer.
 
         Parameters
         ---
-        vocab_size : int
-            Size of the vocabulary to do one-hot encodings.
         encoding_size : int
             Target size of the encoding.
+        vocab : AbstractVocab
+            Vocabulary used for the encodings.
         """
         super().__init__()
 
-        self.encoding_layer = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=encoding_size,
-            padding_idx=PADDING_TOKEN_INDEX)
+        self.vocab_size = len(vocab)
+        self.padding_token_index = vocab.padding_token_index
+        self.embeddings = vocab.name
+
+        if self.embeddings == "bert":
+            bert_encoding_size = 768
+            self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
+            self.encoding_layer = nn.Linear(
+                in_features=bert_encoding_size, out_features=encoding_size)
+        else:
+            self.encoding_layer = nn.Embedding(
+                num_embeddings=self.vocab_size,
+                embedding_dim=encoding_size,
+                padding_idx=self.padding_token_index)
 
     def forward(self, sentences):
         """
@@ -58,8 +71,24 @@ def forward(self, sentences):
         sen_length = sentences.shape[2]
         flattened = reshaped.reshape(-1, sen_length)
 
-        encoded_flat = self.encoding_layer(flattened)
-        pooled_flat = encoded_flat.sum(dim=1)
+
+        if self.embeddings == "bert":
+            # We don't want to fine-tune BERT!
+            with torch.no_grad():
+                encoded_layers, _ = self.bert_layer(flattened)
+
+            # We have hidden states for each of the 12 layers
+            # in model bert-base-uncased
+
+            # Remove useless dimension
+            encoded_flat = torch.squeeze(encoded_layers[11])
+            pooled_flat = encoded_flat.sum(dim=1)
+
+            # Reduce dimensionality to 64
+            pooled_flat = self.encoding_layer(pooled_flat)
+        else:
+            encoded_flat = self.encoding_layer(flattened)
+            pooled_flat = encoded_flat.sum(dim=1)
 
         # Re-shape into original form (4D or 3D tensor)
         enc_size = pooled_flat.shape[1]
@@ -249,8 +278,8 @@ class MatchingNetwork(nn.Module):
 
     def __init__(self,
                  name,
+                 vocab,
                  fce=True,
-                 vocab_size=VOCAB_SIZE,
                  processing_steps=5,
                  distance_metric="cosine"):
         """
@@ -260,10 +289,10 @@ def __init__(self,
         ---
         name : str
             Name of the model. Used for storing checkpoints.
+        vocab : AbstractVocab
+            AbstractVocab object.
         fce : bool
             Flag to decide if we should use Full Context Embeddings.
-        vocab_size : int
-            Size of the vocabulary to do one-hot encodings.
         processing_steps : int
             How many processing steps to take when embedding
             the target query.
@@ -275,9 +304,9 @@ def __init__(self,
         self.name = name
 
         self.encoding_size = 64
-        self.vocab_size = vocab_size
+        self.vocab_size = len(vocab)
 
-        self.encode = EncodingLayer(self.vocab_size, self.encoding_size)
+        self.encode = EncodingLayer(self.encoding_size, vocab)
         self.g = GLayer(self.encoding_size, fce=fce)
         self.f = FLayer(
             self.encoding_size, processing_steps=processing_steps)
@@ -301,7 +330,8 @@ def _similarity(self, support_embeddings, target_embeddings):
         """
         batch_size, N, k, _ = support_embeddings.shape
         T = target_embeddings.shape[1]
-        similarities = torch.zeros(batch_size, T, N, k)
+        similarities = torch.zeros((batch_size, T, N, k),
+                                   device=support_embeddings.device)
         similarity_func = get_similarity_func(self.distance_metric)
 
         # TODO: Would be good to optimise this so that it's vectorised.
@@ -369,7 +399,8 @@ def _to_logits(self, attention, labels):
         # Sum across labels
         attention = attention.sum(dim=3)
         batch_size, T, N = attention.shape
-        logits = torch.zeros((batch_size, T, self.vocab_size))
+        logits = torch.zeros((batch_size, T, self.vocab_size),
+                             device=attention.device)
 
         # TODO: Would be good to optimise this so that it's vectorised.
         for batch_idx in range(batch_size):
diff --git a/src/vocab.py b/src/vocab.py
new file mode 100644
index 0000000..424dd12
--- /dev/null
+++ b/src/vocab.py
@@ -0,0 +1,429 @@
+import json
+import torch
+import numpy as np
+import pandas as pd
+
+from collections import defaultdict, Counter
+
+from torchtext.vocab import Vocab
+from torchtext.data import Field, TabularDataset
+
+from pytorch_pretrained_bert import BertTokenizer
+
+
+class AbstractVocab(object):
+    """
+    Abstract interface for the Vocab classes, which map between text
+    and numbers.
+    """
+
+    name = ""
+    padding_token_index = 0
+
+    def __len__(self):
+        raise NotImplementedError()
+
+    def to_tensors(self, file_path):
+        raise NotImplementedError()
+
+    def to_text(self, X):
+        raise NotImplementedError()
+
+
+class VanillaVocab(AbstractVocab):
+    """
+    Maps between text and numbers using a simple tokenizer.
+    """
+
+    name = "vanilla"
+    padding_token_index = 1
+
+    def __init__(self, file_path):
+        """
+        Initialise the vocabulary by reading it from a file path.
+
+        Parameters
+        ---
+        file_path : str
+            Path to the vocab file.
+        """
+        super().__init__()
+
+        self.vocab = self._read_vocab(file_path)
+
+    def _read_vocab(self, file_path):
+        """
+        Reads a vocab from its previously stored state.
+
+        Inspired by https://github.com/pytorch/text/issues/253#issue-305929871
+
+        Parameters
+        ---
+        file_path : str
+            Path to the JSON file with the vocab info.
+
+        Returns
+        ---
+        vocab : torchtext.Vocab
+            Vocabulary created.
+        """
+        vocab_state = {}
+        with open(file_path) as file:
+            vocab_state = json.load(file)
+
+        vocab = Vocab(Counter())
+        vocab.__dict__.update(vocab_state)
+        vocab.stoi = defaultdict(lambda: 0, vocab.stoi)
+        return vocab
+
+    def __len__(self):
+        """
+        Returns the size of the vocabulary.
+
+        Returns
+        ---
+        int
+            Number of tokens in the vocabulary.
+        """
+        return len(self.vocab)
+
+    def save(self, file_path):
+        """
+        Stores a vocab in a JSON file.
+
+        Inspired by https://github.com/pytorch/text/issues/253#issue-305929871
+
+        Parameters
+        ---
+        file_path : str
+            Path to the vocab state to write.
+
+        """
+        vocab_state = dict(self.vocab.__dict__, stoi=dict(self.vocab.stoi))
+        with open(file_path, 'w') as file:
+            json.dump(vocab_state, file)
+
+    def to_tensors(self, file_path):
+        """
+        Reads the data set from one of the pre-processed CSVs composed
+        of columns `label` and `sentence`.
+
+        Parameters
+        ---
+        file_path : str
+            Path to the CSV file.
+
+        Returns
+        ---
+        X : torch.Tensor[num_labels x num_examples x sen_length]
+            Sentences on the dataset grouped by labels.
+        y : torch.Tensor[num_labels]
+            Labels for each group of sentences.
+        """
+        sentence = Field(
+            batch_first=True, sequential=True, tokenize=self._tokenizer)
+        sentence.vocab = self.vocab
+
+        label = Field(is_target=True)
+        label.vocab = self.vocab
+
+        data_set = TabularDataset(
+            path=file_path,
+            format='csv',
+            skip_header=True,
+            fields=[('label', label), ('sentence', sentence)])
+
+        sentences_tensor = sentence.process(data_set.sentence)
+        labels_tensor = label.process(data_set.label).squeeze()
+
+        return _reshape_tensors(sentences_tensor, labels_tensor)
+
+    @classmethod
+    def _tokenizer(cls, text):
+        """
+        Simple tokenizer which splits the text on the space
+        character. The CSVs have already been pre-processed with
+        spaCy, therefore this should be enough.
+
+        Parameters
+        ---
+        text : str
+            Input text to tokenize.
+
+        Returns
+        ---
+        list
+            List of token strings.
+        """
+        return text.split(' ')
+
+    def to_text(self, X):
+        """
+        Reverses some numericalised tensor into text.
+
+        Parameters
+        ----
+        X : torch.Tensor[num_elements x sen_length]
+            Sentences on the tensor.
+
+        Returns
+        ----
+        sentences : np.array[num_elements]
+            Array of strings.
+        """
+        sentences = []
+        for sentence_tensor in X:
+            if len(sentence_tensor.shape) == 0:
+                # 0-D tensor
+                sentences.append(self.vocab.itos[sentence_tensor])
+                continue
+
+            sentence = [
+                self.vocab.itos[token] for token in sentence_tensor
+                if token != self.padding_token_index
+            ]
+            sentences.append(' '.join(sentence))
+
+        return np.array(sentences)
+
+    @classmethod
+    def generate_vocab(cls, file_path):
+        """
+        Generate the vocabulary from one of the pre-processed CSVs composed
+        of columns `label` and `sentence`.
+
+        Parameters
+        ---
+        file_path : str
+            Path to the CSV file.
+
+        Returns
+        ---
+        vocab : torchtext.Vocab
+            Vocabulary generated from the file.
+        """
+        text = Field(sequential=True, tokenize=cls._tokenizer)
+
+        data_set = TabularDataset(
+            path=file_path,
+            format='csv',
+            skip_header=True,
+            fields=[('label', text), ('sentence', text)])
+
+        text.build_vocab(data_set.label, data_set.sentence)
+        return text.vocab
+
+
+class BertVocab(AbstractVocab):
+    """
+    Implementation of mappings between text and tensors using Bert.
+    """
+
+    name = "bert"
+    padding_token_index = 0
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialise Bert's tokenizer.
+        """
+        super().__init__()
+
+        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+    def __len__(self):
+        """
+        Returns the length of Bert's vocabulary.
+
+        Returns
+        ---
+        int
+            Length of the vocabulary.
+        """
+        return len(self.tokenizer.vocab)
+
+    def to_tensors(self, file_path):
+        """
+        Reads the data set from one of the pre-processed CSVs composed
+        of columns `label` and `sentence`.
+
+        Parameters
+        ---
+        file_path : str
+            Path to the CSV file.
+
+        Returns
+        ---
+        X : torch.Tensor[num_labels x num_examples x sen_length]
+            Sentences on the dataset grouped by labels.
+        y : torch.Tensor[num_labels]
+            Labels for each group of sentences.
+        """
+        data_set = pd.read_csv(file_path)
+
+        # Convert into tokens and find max sen length
+        sentences_tokens, labels_tokens, sen_length = self._to_tokens(data_set)
+
+        # Convert into tensors
+        num_elems = len(sentences_tokens)
+        sentences_tensor = torch.zeros((num_elems, sen_length))
+        labels_tensor = torch.zeros(num_elems)
+
+        for idx in range(num_elems):
+            tensor_sentence = self.tokenizer.convert_tokens_to_ids(
+                sentences_tokens[idx])
+            tensor_label = self.tokenizer.convert_tokens_to_ids(
+                labels_tokens[idx])
+
+            sentences_tensor[idx, :len(tensor_sentence)] = torch.Tensor(
+                tensor_sentence)
+            labels_tensor[idx] = tensor_label[0]
+
+        return _reshape_tensors(sentences_tensor, labels_tensor)
+
+    def _to_tokens(self, data_set):
+        """
+        Tokenize the dataset.
+
+        Parameters
+        ---
+        data_set : pd.DataFrame[label, sentence]
+            Dataset with two columns.
+
+        Returns
+        ---
+        sentences_tokens : list
+            List of tokenized sentences.
+        labels_tokens : list
+            List of tokenized labels.
+        sen_length : int
+            Maximum sentence length.
+        """
+        sentences_tokens = []
+        labels_tokens = []
+        sen_length = 0
+        for idx, row in data_set.iterrows():
+            token_sentence = self._tokenize(row['sentence'])
+            token_label = self._tokenize(row['label'])
+
+            # TODO: This is a shortcut to avoid dealing with
+            # situations where Bert's word-piece tokenizer
+            # splits a label into multiple tokens and thus
+            # multiple token ids to predict for a single input.
+            if len(token_label) > 1:
+                continue
+                # raise ValueError(f"Label '{row['label']}' was split "
+                #                  f"into more than one tokens: "
+                #                  f"{token_label}")
+
+            length = len(token_sentence)
+            if length > sen_length:
+                sen_length = length
+
+            sentences_tokens.append(token_sentence)
+            labels_tokens.append(token_label)
+
+        return sentences_tokens, labels_tokens, sen_length
+
+    def _tokenize(self, text):
+        """
+        Tokenize a text using Bert's tokenizer but processing it first to
+        replace:
+
+        - <unk> => [UNK]
+        - <mask> => [MASK]
+
+        Parameters
+        ---
+        text : str
+            Input string.
+
+        Returns
+        ---
+        list
+            List of tokens.
+        """
+        with_unk = text.replace('<unk>', '[UNK]')
+        with_mask = with_unk.replace('<mask>', '[MASK]')
+
+        return self.tokenizer.tokenize(with_mask)
+
+    def to_text(self, X):
+        """
+        Reverses some numericalised tensor into text.
+
+        Parameters
+        ----
+        X : torch.Tensor[num_elements x sen_length]
+            Sentences on the tensor.
+
+        Returns
+        ----
+        sentences : np.array[num_elements]
+            Array of strings.
+        """
+        sentences = []
+        for sentence_tensor in X:
+            if len(sentence_tensor.shape) == 0:
+                # 0-D tensor
+                sentences.append(self.tokenizer.ids_to_tokens[sentence_tensor])
+                continue
+
+            sentence = [
+                self.tokenizer.ids_to_tokens[token_id]
+                for token_id in sentence_tensor
+                if token_id != self.padding_token_index
+            ]
+            sentences.append(' '.join(sentence))
+
+        return np.array(sentences)
+
+
+def _reshape_tensors(sentences_tensor, labels_tensor):
+    """
+    Reshape tensors to the [N x k x sen_length] structure.
+
+    Parameters
+    ---
+    sentences_tensor : torch.Tensor[num_elems x sen_length]
+        Flat tensor with all the sentences.
+    labels_tensor : torch.Tensor[num_elems]
+        Flat tensor with all the labels.
+
+    Returns
+    ---
+    X : torch.Tensor[num_labels x num_examples x sen_length]
+        Sentences on the dataset grouped by labels.
+    y : torch.Tensor[num_labels]
+        Labels for each group of sentences.
+    """
+    # Infer num_labels and num_examples by label
+    num_labels = labels_tensor.unique().shape[0]
+    num_examples = labels_tensor.shape[0] // num_labels
+    y = labels_tensor[::num_examples]
+
+    # More robust to potentially duplicated labels
+    num_labels = y.shape[0]
+
+    sen_length = sentences_tensor.shape[-1]
+    X = sentences_tensor.view(num_labels, num_examples, sen_length)
+
+    return X, y
+
+
+VOCABS = {'vanilla': VanillaVocab, 'bert': BertVocab}
+
+
+def get_vocab(embeddings, *args, **kwargs):
+    """
+    Returns an initialised vocab, forwarding the extra args and kwargs.
+
+    Parameters
+    ---
+    embeddings : str
+        Embeddings to use. Can be one of the VOCABS keys.
+
+    Returns
+    ---
+    AbstractVocab
+    """
+    return VOCABS[embeddings](*args, **kwargs)
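
Usage sketch (not part of the patch): a minimal example of how the pieces fit together after this change, based on bin/train.py and bin/test.py above. The file paths, model name, and episode hyperparameters are placeholders.

from torch.utils.data import DataLoader

from src.vocab import get_vocab
from src.matching_network import MatchingNetwork
from src.datasets import EpisodesSampler, EpisodesDataset

# "vanilla" reads the stored vocab JSON; "bert" ignores the path and uses
# BertTokenizer's own vocabulary. Both paths below are placeholders.
vocab = get_vocab("vanilla", "vocab.json")

# Each AbstractVocab turns a pre-processed CSV directly into tensors.
X_train, y_train = vocab.to_tensors("train.csv")

# The model now takes the vocab object instead of a vocab_size.
model = MatchingNetwork(
    "example_model", vocab, fce=True, distance_metric="cosine")

# Illustrative N/k/batch_size values.
train_set = EpisodesDataset(X_train, y_train, k=5)
sampler = EpisodesSampler(train_set, N=5, episodes_multiplier=1)
train_loader = DataLoader(train_set, sampler=sampler, batch_size=32)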
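For reference, a small standalone sketch of what the new "bert" branch of EncodingLayer does with a flattened batch, using the pytorch-pretrained-bert API added to the Pipfile. The sentence and the Linear layer here are illustrative stand-ins for the real encoder, not part of the patch.

import torch
from torch import nn
from pytorch_pretrained_bert import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
bert.eval()

# A toy "flattened" batch of token ids, shape [batch, sen_length].
tokens = tokenizer.tokenize("book me a table for two")
flattened = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

# No fine-tuning: BERT is only used as a frozen feature extractor.
with torch.no_grad():
    encoded_layers, _ = bert(flattened)

# encoded_layers is a list with one [batch, sen_length, 768] tensor per
# layer; the patch takes the last one, sums over tokens and projects
# 768 -> 64 with the Linear encoding layer.
pooled = encoded_layers[-1].sum(dim=1)
projection = nn.Linear(in_features=768, out_features=64)
encoding = projection(pooled)
print(encoding.shape)  # torch.Size([1, 64])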