From a9a49adf49fb499fc8df64115fb9ac211ef6c280 Mon Sep 17 00:00:00 2001
From: e3oroush
Date: Thu, 9 Dec 2021 12:13:22 +0100
Subject: [PATCH] verbose message to increase size_embeddings_count

---
 jerex/data_module.py             | 15 ++++++++++-----
 jerex/datasets.py                |  7 +++++--
 jerex/model.py                   |  3 ++-
 jerex/sampling/sampling_joint.py |  4 +++-
 4 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/jerex/data_module.py b/jerex/data_module.py
index 95fef0a..ee7ad62 100644
--- a/jerex/data_module.py
+++ b/jerex/data_module.py
@@ -19,7 +19,8 @@ def __init__(self, tokenizer: BertTokenizer, task_type: str, types_path: str = N
                  sampling_processes: int = 4, neg_mention_count: int = 50, neg_relation_count: int = 50,
                  neg_coref_count: int = 50, max_span_size: int = 10,
                  neg_mention_overlap_ratio: float = 0.5,
-                 final_valid_evaluate: bool = False):
+                 final_valid_evaluate: bool = False,
+                 size_embeddings_count: int = 30):
         super().__init__()
 
         if types_path is not None:
@@ -56,6 +57,7 @@ def __init__(self, tokenizer: BertTokenizer, task_type: str, types_path: str = N
         self._neg_coref_count = neg_coref_count
         self._max_span_size = max_span_size
         self._neg_mention_overlap_ratio = neg_mention_overlap_ratio
+        self._size_embeddings_count = size_embeddings_count
 
         self._train_path = train_path
         self._valid_path = valid_path
@@ -81,7 +83,8 @@ def setup(self, stage=None):
                                                  neg_rel_count=self._neg_relation_count,
                                                  max_span_size=self._max_span_size,
                                                  neg_mention_overlap_ratio=self._neg_mention_overlap_ratio,
-                                                 tokenizer=self._tokenizer)
+                                                 tokenizer=self._tokenizer,
+                                                 size_embeddings_count=self._size_embeddings_count)
 
             self._train_dataset.switch_task(self._task_type)
             self._train_dataset.switch_mode(DocREDDataset.TRAIN_MODE)
@@ -91,7 +94,8 @@ def setup(self, stage=None):
                                                  entity_types=self._entity_types,
                                                  relation_types=self._relation_types,
                                                  max_span_size=self._max_span_size,
-                                                 tokenizer=self._tokenizer)
+                                                 tokenizer=self._tokenizer,
+                                                 size_embeddings_count=self._size_embeddings_count)
 
             self._valid_dataset.switch_task(self._task_type)
             self._valid_dataset.switch_mode(DocREDDataset.INFERENCE_MODE)
@@ -104,7 +108,8 @@ def setup(self, stage=None):
                                                     entity_types=self._entity_types,
                                                     relation_types=self._relation_types,
                                                     max_span_size=self._max_span_size,
-                                                    tokenizer=self._tokenizer)
+                                                    tokenizer=self._tokenizer,
+                                                    size_embeddings_count=self._size_embeddings_count)
             else:
                 self._test_dataset = self._valid_dataset
 
@@ -112,7 +117,7 @@ def setup(self, stage=None):
         self._test_dataset.switch_mode(DocREDDataset.INFERENCE_MODE)
 
     def train_dataloader(self):
-        return DataLoader(self._train_dataset, batch_size=self._train_batch_size, shuffle=True, drop_last=True,
+        return DataLoader(self._train_dataset, batch_size=self._train_batch_size, shuffle=False, drop_last=True,
                           num_workers=self._sampling_processes, collate_fn=collate_fn_padding)
 
diff --git a/jerex/datasets.py b/jerex/datasets.py
index 91cc7a1..ce8ae41 100644
--- a/jerex/datasets.py
+++ b/jerex/datasets.py
@@ -19,7 +19,8 @@ class DocREDDataset(TorchDataset):
     INFERENCE_MODE = 'inference'
 
     def __init__(self, dataset_path, entity_types, relation_types, tokenizer, neg_mention_count=200,
-                 neg_rel_count=200, neg_coref_count=200, max_span_size=10, neg_mention_overlap_ratio=0.5):
+                 neg_rel_count=200, neg_coref_count=200, max_span_size=10, neg_mention_overlap_ratio=0.5,
+                 size_embeddings_count=30):
         self._dataset_path = dataset_path
         self._entity_types = entity_types
         self._relation_types = relation_types
@@ -29,6 +30,7 @@ def __init__(self, dataset_path, entity_types, relation_types, tokenizer, neg_me
         self._max_span_size = max_span_size
         self._neg_mention_overlap_ratio = neg_mention_overlap_ratio
         self._tokenizer = tokenizer
+        self._size_embeddings_count = size_embeddings_count
 
         self._mode = DocREDDataset.TRAIN_MODE
         self._task = None
@@ -204,7 +206,8 @@ def __getitem__(self, index: int):
             return sampling_joint.create_joint_train_sample(doc, self._neg_mention_count, self._neg_rel_count,
                                                             self._neg_coref_count, self._max_span_size,
                                                             self._neg_mention_overlap_ratio,
-                                                            len(self._relation_types))
+                                                            len(self._relation_types),
+                                                            self._size_embeddings_count)
         elif self._task == TaskType.MENTION_LOCALIZATION:
             return sampling_classify.create_mention_classify_train_sample(doc, self._neg_mention_count,
                                                                            self._max_span_size,
diff --git a/jerex/model.py b/jerex/model.py
index ecec710..76c9d3d 100644
--- a/jerex/model.py
+++ b/jerex/model.py
@@ -286,7 +286,8 @@ def train(cfg: TrainConfig):
                                    max_span_size=cfg.sampling.max_span_size,
                                    neg_mention_overlap_ratio=cfg.sampling.neg_mention_overlap_ratio,
                                    final_valid_evaluate=cfg.misc.final_valid_evaluate
-                                   and cfg.datasets.test_path is None)
+                                   and cfg.datasets.test_path is None,
+                                   size_embeddings_count=cfg.model.size_embeddings_count)
 
     data_module.setup('fit')
 
diff --git a/jerex/sampling/sampling_joint.py b/jerex/sampling/sampling_joint.py
index 11ac3ee..791af51 100644
--- a/jerex/sampling/sampling_joint.py
+++ b/jerex/sampling/sampling_joint.py
@@ -7,7 +7,8 @@
 
 
 def create_joint_train_sample(doc: Document, neg_mention_count: int, neg_rel_count: int, neg_coref_count: int,
-                              max_span_size: int, neg_mention_overlap_ratio: float, rel_type_count: int):
+                              max_span_size: int, neg_mention_overlap_ratio: float, rel_type_count: int,
+                              size_embeddings_count: int):
     encodings = doc.encodings  # document sub-word encoding
     context_size = len(encodings)
 
@@ -89,6 +90,7 @@ def create_joint_train_sample(doc: Document, neg_mention_count: int, neg_rel_cou
     assert len(coref_mention_pairs) == len(coref_sample_masks) == len(coref_types) == len(coref_eds)
     assert len(entities) == len(entity_types)
     assert len(rel_entity_pairs) == len(rel_types)
+    assert mention_sizes.max().item() < size_embeddings_count, f"You should increase `size_embeddings_count` in the model section of the config to at least {mention_sizes.max().item() + 1}"
     return dict(encodings=encodings, context_masks=context_masks, mention_masks=mention_masks,
                 mention_sizes=mention_sizes, mention_types=mention_types, mention_sample_masks=mention_sample_masks,