@@ -26,7 +26,7 @@ def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None:
26
26
27
27
@staticmethod
28
28
@abstractmethod
29
- def post_process_example (
29
+ def _post_process_example (
30
30
example : dict ,
31
31
instruction : str ,
32
32
task_id : int ,
@@ -83,13 +83,13 @@ def filter_empty_strings(example: dict) -> bool:
83
83
"input_col" in example and "output_col" in example
84
84
), "Example dictionary must have 'input_col' and 'output_col' keys."
85
85
# Check if 'input_col' and 'output_col' are both non-empty strings
86
- return bool (example ["input_col" ]) and bool (example ["output_col" ])
86
+ return bool (str ( example ["input_col" ])) and bool (str ( example ["output_col" ]) )
87
87
88
88
for task_id , dataset_dict in enumerate (dataset_dicts ):
89
89
modified_dataset_dict = {}
90
90
for dataset_split in list (dataset_dict .keys ()):
91
91
mapping_function = partial (
92
- self .post_process_example ,
92
+ self ._post_process_example ,
93
93
instruction = instruction ,
94
94
task_id = task_id ,
95
95
has_encoder = self .has_encoder ,
@@ -104,3 +104,110 @@ def filter_empty_strings(example: dict) -> bool:
104
104
modified_dataset_dict = datasets .DatasetDict (modified_dataset_dict )
105
105
modified_dataset_dicts .append (modified_dataset_dict )
106
106
return modified_dataset_dicts
107
+
108
@staticmethod
def _split_dataset_into_dataset_dict(
    dataset,
    train_proportion: float = 0.8,
    val_proportion: float = 0.1,
    maximum_example_num: int | None = None,
) -> datasets.DatasetDict:
    """Split a dataset into `train`, `val`, and `test` splits.

    The split sizes come from the given proportions; whatever remains
    after `train` and `val` is assigned to `test`. If a maximum example
    count is supplied, each split is independently capped at it.

    Args:
        dataset: The original dataset to be split.
        train_proportion: Proportion of examples for the `train` set.
        val_proportion: Proportion of examples for the `val` set.
        maximum_example_num: Maximum number of examples
            to include in each set.

    Returns:
        datasets.DatasetDict: A dictionary containing the `train`,
        `val`, and `test` datasets.
    """
    total = len(dataset)
    # int() truncates, so any rounding remainder falls into `test`.
    counts = {
        "train": int(train_proportion * total),
        "val": int(val_proportion * total),
    }
    counts["test"] = total - counts["train"] - counts["val"]

    if maximum_example_num is not None:
        # Cap every split independently at the requested ceiling.
        counts = {
            name: min(size, maximum_example_num)
            for name, size in counts.items()
        }

    # Consecutive, non-overlapping slices of the original dataset.
    train_end = counts["train"]
    val_end = train_end + counts["val"]
    test_end = val_end + counts["test"]

    return datasets.DatasetDict(
        {
            "train": datasets.Dataset.from_dict(dataset[:train_end]),
            "val": datasets.Dataset.from_dict(dataset[train_end:val_end]),
            "test": datasets.Dataset.from_dict(dataset[val_end:test_end]),
        }
    )
154
+
155
+ @staticmethod
156
+ def wrap_single_input (instruction : str , input : str ):
157
+ """Wrap an input string into text2text fashion to be the input of model.
158
+
159
+ Args:
160
+ instruction: The instruction used as a prefix to explain the task.
161
+ input: An input string to be wrapped.
162
+
163
+ Return:
164
+ A wrapped input string.
165
+ """
166
+ return f"<task 0>{ instruction } \n Example:\n { input } \n Label:\n "
167
+
168
def process_dataset_lists(
    self,
    instruction: str,
    dataset_list: list[datasets.Dataset],
    train_proportion: float = 0.8,
    val_proportion: float = 0.1,
    maximum_example_num: int | None = None,
) -> list[datasets.DatasetDict]:
    """Post-processes both the generated and retrieved datasets.

    This function takes in datasets generated by `DatasetGenerator`
    and retrieved by `DatasetRetriever`. It modifies these datasets
    based on a given instruction, converting all examples into a
    text-to-text format.

    Args:
        instruction: The instruction used as a prefix to explain the task.
        dataset_list: A list of datasets. It can be either generated by
            the DatasetGenerator or retrieved by the DatasetRetriever.
        train_proportion: The proportion of examples used for `train`.
        val_proportion: The proportion of examples used for `val`.
        maximum_example_num: The maximum number of examples to
            be used for `train`, `val` and `test`.

    Returns:
        list[datasets.DatasetDict]: A list of DatasetDicts, all examples
        are converted into text2text fashion.

    Note:
        The DatasetRetriever returns a DatasetDict with multiple splits.
        Any of these splits can be passed into this function.
        The remaining proportion after allocating to `train` and
        `val` will be used for the `test` set.
    """
    # `test` receives 1 - train - val, so the first two must leave room.
    if train_proportion + val_proportion >= 1:
        raise ValueError(
            f"train_proportion {train_proportion} + val_proportion {val_proportion} must be less than 1."  # noqa E501
        )

    # Split every dataset into train/val/test before text2text conversion.
    split_dataset_dicts = []
    for dataset in dataset_list:
        split_dataset_dicts.append(
            self._split_dataset_into_dataset_dict(
                dataset, train_proportion, val_proportion, maximum_example_num
            )
        )
    return self.process_dataset_dict(instruction, split_dataset_dicts)
0 commit comments