Skip to content

Commit ce4626f

Browse files
viswavi and Eren Chenyang Zhao authored
Remove unnecessary columns in the dataset processor (#279)
* Remove unnecessary columns in the dataset processor
* Modify dataset.map to achieve this
* Add numerical column test
* Add test case to test concatenating columns of different types
* Improve docstring and test name

---------

Co-authored-by: Eren Chenyang Zhao <chenyan3@andrew.cmu.edu>
1 parent 6006d1e commit ce4626f

File tree

2 files changed

+83
-33
lines changed

2 files changed

+83
-33
lines changed

prompt2model/dataset_processor/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def filter_empty_strings(example: dict) -> bool:
9999
modified_dataset_dict[dataset_split] = (
100100
dataset_dict[dataset_split]
101101
.filter(filter_empty_strings)
102-
.map(mapping_function)
102+
.map(mapping_function, remove_columns=["input_col", "output_col"])
103103
)
104104
modified_dataset_dict = datasets.DatasetDict(modified_dataset_dict)
105105
modified_dataset_dicts.append(modified_dataset_dict)

tests/dataset_processor_test.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ def test_dataset_processor_t5_style():
138138
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
139139
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
140140
],
141-
"input_col": ["foo", "bar"],
142-
"output_col": ["baz", "qux"],
143141
"model_output": ["baz", "qux"],
144142
}
145143
),
@@ -149,8 +147,6 @@ def test_dataset_processor_t5_style():
149147
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
150148
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
151149
],
152-
"input_col": ["foo", "bar"],
153-
"output_col": ["baz", "qux"],
154150
"model_output": ["baz", "qux"],
155151
}
156152
),
@@ -164,8 +160,6 @@ def test_dataset_processor_t5_style():
164160
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
165161
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
166162
],
167-
"input_col": ["spam", "eggs"],
168-
"output_col": ["ham", "sau"],
169163
"model_output": ["ham", "sau"],
170164
}
171165
),
@@ -175,8 +169,6 @@ def test_dataset_processor_t5_style():
175169
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
176170
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
177171
],
178-
"input_col": ["spam", "eggs"],
179-
"output_col": ["ham", "sau"],
180172
"model_output": ["ham", "sau"],
181173
}
182174
),
@@ -188,6 +180,88 @@ def test_dataset_processor_t5_style():
188180
gc.collect()
189181

190182

183+
def test_dataset_processor_with_numerical_column():
184+
"""Test process_dataset_dict with numerical column values."""
185+
t5_processor = TextualizeProcessor(has_encoder=True)
186+
raw_dataset_dicts = [
187+
datasets.DatasetDict(
188+
{
189+
"train": datasets.Dataset.from_dict(
190+
{
191+
"input_col": ["foo", "bar"],
192+
"output_col": ["baz", "qux"],
193+
}
194+
),
195+
"test": datasets.Dataset.from_dict(
196+
{
197+
"input_col": ["spam", "eggs"],
198+
"output_col": ["ham", "sau"],
199+
}
200+
),
201+
}
202+
),
203+
datasets.DatasetDict(
204+
{
205+
"train": datasets.Dataset.from_dict(
206+
{
207+
"input_col": ["foo", "bar"],
208+
"output_col": [0, 1],
209+
}
210+
),
211+
"test": datasets.Dataset.from_dict(
212+
{
213+
"input_col": ["spam", "eggs"],
214+
"output_col": [1, 2],
215+
}
216+
),
217+
}
218+
),
219+
]
220+
t5_modified_dataset_dicts = t5_processor.process_dataset_dict(
221+
INSTRUCTION, raw_dataset_dicts
222+
)
223+
expected_dataset_dict = datasets.DatasetDict(
224+
{
225+
"train": datasets.Dataset.from_dict(
226+
{
227+
"model_input": [
228+
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
229+
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
230+
"<task 1>convert to text2text\nExample:\nfoo\nLabel:\n",
231+
"<task 1>convert to text2text\nExample:\nbar\nLabel:\n",
232+
],
233+
"model_output": ["foo", "bar", "0", "1"],
234+
}
235+
),
236+
"test": datasets.Dataset.from_dict(
237+
{
238+
"model_input": [
239+
"<task 0>convert to text2text\nExample:\nspam\nLabel:\n",
240+
"<task 0>convert to text2text\nExample:\neggs\nLabel:\n",
241+
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
242+
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
243+
],
244+
"model_output": ["ham", "sau", "1", "2"],
245+
}
246+
),
247+
}
248+
)
249+
training_datasets = []
250+
test_datasets = []
251+
for modified_dataset_dict in t5_modified_dataset_dicts:
252+
training_datasets.append(modified_dataset_dict["train"])
253+
test_datasets.append(modified_dataset_dict["test"])
254+
255+
concatenated_training_dataset = datasets.concatenate_datasets(training_datasets)
256+
concatenated_test_dataset = datasets.concatenate_datasets(test_datasets)
257+
actual_dataset_dict = datasets.DatasetDict(
258+
{"train": concatenated_training_dataset, "test": concatenated_test_dataset}
259+
)
260+
are_dataset_dicts_identical(expected_dataset_dict, actual_dataset_dict)
261+
262+
gc.collect()
263+
264+
191265
def test_dataset_processor_decoder_only_style():
192266
"""Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`."""
193267
_, gpt2_tokenizer = create_gpt2_model_and_tokenizer()
@@ -213,8 +287,6 @@ def test_dataset_processor_decoder_only_style():
213287
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\nbaz<|endoftext|>", # noqa: E501
214288
"<task 0>convert to text2text\nExample:\nbar\nLabel:\nqux<|endoftext|>", # noqa: E501
215289
],
216-
"input_col": ["foo", "bar"],
217-
"output_col": ["baz", "qux"],
218290
"model_output": ["baz<|endoftext|>", "qux<|endoftext|>"],
219291
}
220292
),
@@ -224,8 +296,6 @@ def test_dataset_processor_decoder_only_style():
224296
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
225297
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
226298
],
227-
"input_col": ["foo", "bar"],
228-
"output_col": ["baz", "qux"],
229299
"model_output": ["baz", "qux"],
230300
}
231301
),
@@ -239,8 +309,6 @@ def test_dataset_processor_decoder_only_style():
239309
"<task 1>convert to text2text\nExample:\nspam\nLabel:\nham<|endoftext|>", # noqa: E501
240310
"<task 1>convert to text2text\nExample:\neggs\nLabel:\nsau<|endoftext|>", # noqa: E501
241311
],
242-
"input_col": ["spam", "eggs"],
243-
"output_col": ["ham", "sau"],
244312
"model_output": ["ham<|endoftext|>", "sau<|endoftext|>"],
245313
}
246314
),
@@ -250,8 +318,6 @@ def test_dataset_processor_decoder_only_style():
250318
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
251319
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
252320
],
253-
"input_col": ["spam", "eggs"],
254-
"output_col": ["ham", "sau"],
255321
"model_output": ["ham", "sau"],
256322
}
257323
),
@@ -341,8 +407,6 @@ def test_empty_filter_t5_type():
341407
"model_input": [
342408
"<task 0>convert to text2text\nExample:\ntest\nLabel:\n",
343409
],
344-
"input_col": ["test"],
345-
"output_col": ["key"],
346410
"model_output": ["key"],
347411
}
348412
),
@@ -351,12 +415,6 @@ def test_empty_filter_t5_type():
351415
"model_input": [
352416
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
353417
],
354-
"input_col": [
355-
"foo",
356-
],
357-
"output_col": [
358-
"baz",
359-
],
360418
"model_output": [
361419
"baz",
362420
],
@@ -369,8 +427,6 @@ def test_empty_filter_t5_type():
369427
"train": datasets.Dataset.from_dict(
370428
{
371429
"model_input": [],
372-
"input_col": [],
373-
"output_col": [],
374430
"model_output": [],
375431
}
376432
),
@@ -403,8 +459,6 @@ def test_empty_filter_decoder_only_style():
403459
"model_input": [
404460
"<task 0>convert to text2text\nExample:\ntest\nLabel:\nkey<|endoftext|>", # noqa: E501
405461
],
406-
"input_col": ["test"],
407-
"output_col": ["key"],
408462
"model_output": ["key<|endoftext|>"],
409463
}
410464
),
@@ -413,8 +467,6 @@ def test_empty_filter_decoder_only_style():
413467
"model_input": [
414468
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
415469
],
416-
"input_col": ["foo"],
417-
"output_col": ["baz"],
418470
"model_output": ["baz"],
419471
}
420472
),
@@ -425,8 +477,6 @@ def test_empty_filter_decoder_only_style():
425477
"train": datasets.Dataset.from_dict(
426478
{
427479
"model_input": [],
428-
"input_col": [],
429-
"output_col": [],
430480
"model_output": [],
431481
}
432482
),

0 commit comments

Comments (0)