Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

verbose message to increase size_embeddings_count #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions jerex/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def __init__(self, tokenizer: BertTokenizer, task_type: str, types_path: str = N
sampling_processes: int = 4, neg_mention_count: int = 50,
neg_relation_count: int = 50, neg_coref_count: int = 50,
max_span_size: int = 10, neg_mention_overlap_ratio: float = 0.5,
final_valid_evaluate: bool = False):
final_valid_evaluate: bool = False,
size_embeddings_count: int = 30):
super().__init__()

if types_path is not None:
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, tokenizer: BertTokenizer, task_type: str, types_path: str = N
self._neg_coref_count = neg_coref_count
self._max_span_size = max_span_size
self._neg_mention_overlap_ratio = neg_mention_overlap_ratio
self._size_embeddings_count = size_embeddings_count

self._train_path = train_path
self._valid_path = valid_path
Expand All @@ -81,7 +83,8 @@ def setup(self, stage=None):
neg_rel_count=self._neg_relation_count,
max_span_size=self._max_span_size,
neg_mention_overlap_ratio=self._neg_mention_overlap_ratio,
tokenizer=self._tokenizer)
tokenizer=self._tokenizer,
size_embeddings_count=self._size_embeddings_count)

self._train_dataset.switch_task(self._task_type)
self._train_dataset.switch_mode(DocREDDataset.TRAIN_MODE)
Expand All @@ -91,7 +94,8 @@ def setup(self, stage=None):
entity_types=self._entity_types,
relation_types=self._relation_types,
max_span_size=self._max_span_size,
tokenizer=self._tokenizer)
tokenizer=self._tokenizer,
size_embeddings_count=self._size_embeddings_count)

self._valid_dataset.switch_task(self._task_type)
self._valid_dataset.switch_mode(DocREDDataset.INFERENCE_MODE)
Expand All @@ -104,15 +108,16 @@ def setup(self, stage=None):
entity_types=self._entity_types,
relation_types=self._relation_types,
max_span_size=self._max_span_size,
tokenizer=self._tokenizer)
tokenizer=self._tokenizer,
size_embeddings_count=self._size_embeddings_count)
else:
self._test_dataset = self._valid_dataset

self._test_dataset.switch_task(self._task_type)
self._test_dataset.switch_mode(DocREDDataset.INFERENCE_MODE)

def train_dataloader(self):
return DataLoader(self._train_dataset, batch_size=self._train_batch_size, shuffle=True, drop_last=True,
return DataLoader(self._train_dataset, batch_size=self._train_batch_size, shuffle=False, drop_last=True,
num_workers=self._sampling_processes,
collate_fn=collate_fn_padding)

Expand Down
7 changes: 5 additions & 2 deletions jerex/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class DocREDDataset(TorchDataset):
INFERENCE_MODE = 'inference'

def __init__(self, dataset_path, entity_types, relation_types, tokenizer, neg_mention_count=200,
neg_rel_count=200, neg_coref_count=200, max_span_size=10, neg_mention_overlap_ratio=0.5):
neg_rel_count=200, neg_coref_count=200, max_span_size=10, neg_mention_overlap_ratio=0.5,
size_embeddings_count=30):
self._dataset_path = dataset_path
self._entity_types = entity_types
self._relation_types = relation_types
Expand All @@ -29,6 +30,7 @@ def __init__(self, dataset_path, entity_types, relation_types, tokenizer, neg_me
self._max_span_size = max_span_size
self._neg_mention_overlap_ratio = neg_mention_overlap_ratio
self._tokenizer = tokenizer
self._size_embeddings_count = size_embeddings_count

self._mode = DocREDDataset.TRAIN_MODE
self._task = None
Expand Down Expand Up @@ -204,7 +206,8 @@ def __getitem__(self, index: int):
return sampling_joint.create_joint_train_sample(doc, self._neg_mention_count, self._neg_rel_count,
self._neg_coref_count,
self._max_span_size, self._neg_mention_overlap_ratio,
len(self._relation_types))
len(self._relation_types),
self._size_embeddings_count)
elif self._task == TaskType.MENTION_LOCALIZATION:
return sampling_classify.create_mention_classify_train_sample(doc, self._neg_mention_count,
self._max_span_size,
Expand Down
3 changes: 2 additions & 1 deletion jerex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def train(cfg: TrainConfig):
max_span_size=cfg.sampling.max_span_size,
neg_mention_overlap_ratio=cfg.sampling.neg_mention_overlap_ratio,
final_valid_evaluate=cfg.misc.final_valid_evaluate
and cfg.datasets.test_path is None)
and cfg.datasets.test_path is None,
size_embeddings_count=cfg.model.size_embeddings_count)

data_module.setup('fit')

Expand Down
4 changes: 3 additions & 1 deletion jerex/sampling/sampling_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@


def create_joint_train_sample(doc: Document, neg_mention_count: int, neg_rel_count: int, neg_coref_count: int,
max_span_size: int, neg_mention_overlap_ratio: float, rel_type_count: int):
max_span_size: int, neg_mention_overlap_ratio: float, rel_type_count: int,
size_embeddings_count: int):
encodings = doc.encodings # document sub-word encoding
context_size = len(encodings)

Expand Down Expand Up @@ -89,6 +90,7 @@ def create_joint_train_sample(doc: Document, neg_mention_count: int, neg_rel_cou
assert len(coref_mention_pairs) == len(coref_sample_masks) == len(coref_types) == len(coref_eds)
assert len(entities) == len(entity_types)
assert len(rel_entity_pairs) == len(rel_types)
assert mention_sizes.max().item() < size_embeddings_count, f"You should increase `size_embeddings_count` in the config.model section to at least {mention_sizes.max()}"

return dict(encodings=encodings, context_masks=context_masks, mention_masks=mention_masks,
mention_sizes=mention_sizes, mention_types=mention_types, mention_sample_masks=mention_sample_masks,
Expand Down