import torch
from torch.utils.data import Dataset

class BilingualDataset(Dataset):
    """
    A PyTorch Dataset for bilingual translation pairs.

    Tokenizes each source/target pair and produces fixed-length encoder and
    decoder inputs, attention masks, and labels for an encoder-decoder model.

    Args:
        ds (Dataset): The dataset containing translation pairs.
        tokenizer_src (Tokenizer): The tokenizer for the source language.
        tokenizer_tgt (Tokenizer): The tokenizer for the target language.
        src_lang (str): The source language code.
        tgt_lang (str): The target language code.
        seq_len (int): The fixed sequence length for inputs and outputs.

    Attributes:
        ds (Dataset): The dataset containing translation pairs.
        tokenizer_src (Tokenizer): The tokenizer for the source language.
        tokenizer_tgt (Tokenizer): The tokenizer for the target language.
        src_lang (str): The source language code.
        tgt_lang (str): The target language code.
        seq_len (int): The fixed sequence length for inputs and outputs.
        sos_token (Tensor): The tensor holding the start-of-sequence token id.
        pad_token (Tensor): The tensor holding the padding token id.
        eos_token (Tensor): The tensor holding the end-of-sequence token id.
    """

    def __init__(self,
                 ds,
                 tokenizer_src,
                 tokenizer_tgt,
                 src_lang,
                 tgt_lang,
                 seq_len) -> None:
        super().__init__()
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        # Special-token ids are looked up once from the source tokenizer; this
        # assumes both tokenizers assign the same ids to [SOS], [PAD], and [EOS].
        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.ds)

    def __getitem__(self, index: int) -> dict:
        """
        Retrieves a single data point from the dataset at the specified index.

        Args:
            index (int): The index of the data point to retrieve.

        Returns:
            dict: A dictionary containing encoder inputs, decoder inputs, masks, labels, and original texts.
        """
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder input gains both [SOS] and [EOS] (hence -2); the decoder
        # input gains only [SOS], since [EOS] goes on the label instead (hence -1).
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError('Sentence is too long')

        # Encoder input: [SOS] + tokens + [EOS], padded to seq_len
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token.item()] * enc_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )

        # Decoder input: [SOS] + tokens, padded to seq_len (no [EOS])
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token.item()] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )

        # Label: tokens + [EOS], padded to seq_len (no [SOS]); at each position
        # it holds the token the decoder should predict next
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token.item()] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, 1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text
        }


def causal_mask(size):
    """
    Creates a causal mask for decoder inputs to prevent attention to future tokens.

    Args:
        size (int): The size of the mask (sequence length).

    Returns:
        Tensor: A boolean causal mask of shape (1, size, size), True at and
        below the diagonal (attention allowed), False above it.
    """
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
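

if __name__ == "__main__":
    # Minimal, self-contained smoke test. This sketch is illustrative only:
    # the stub below stands in for a real `tokenizers.Tokenizer` and implements
    # just the two methods BilingualDataset relies on (`token_to_id` and
    # `encode(...).ids`); the toy dataset and language codes are hypothetical.
    from types import SimpleNamespace

    class _StubTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

        def token_to_id(self, token):
            return self.vocab[token]

        def encode(self, text):
            # Whitespace "tokenization", growing the vocab on the fly
            ids = [self.vocab.setdefault(word, len(self.vocab)) for word in text.split()]
            return SimpleNamespace(ids=ids)

    specials = {'[SOS]': 0, '[EOS]': 1, '[PAD]': 2}
    toy_ds = [{'translation': {'en': 'hello world', 'it': 'ciao mondo'}}]
    dataset = BilingualDataset(toy_ds,
                               _StubTokenizer(dict(specials)),
                               _StubTokenizer(dict(specials)),
                               'en', 'it', seq_len=8)

    sample = dataset[0]
    print(sample['encoder_input'])       # tensor([0, 3, 4, 1, 2, 2, 2, 2]): [SOS] w1 w2 [EOS] [PAD]*4
    print(sample['decoder_mask'].shape)  # torch.Size([1, 8, 8])
    print(causal_mask(4))                # lower-triangular boolean mask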