dataset.py
import os
from functools import partial
from typing import TypedDict, cast

import pytorch_lightning as pl
import torch
from datasets import DatasetDict, load_dataset
from torch import Tensor
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, AutoTokenizer


class MMSDModelInput(TypedDict, total=False):
    """A single collated batch as consumed by the model."""

    pixel_values: Tensor
    input_ids: Tensor
    attention_mask: Tensor
    label: Tensor
    id: list[str]


def preprocess(example, image_processor, tokenizer):
    """Convert raw image/text examples into model-ready features."""
    image_inputs = image_processor(images=example["image"])
    text_inputs = tokenizer(
        text=example["text"],
        truncation=True,
        padding="max_length",
    )
    return {
        "pixel_values": image_inputs["pixel_values"],
        "input_ids": text_inputs["input_ids"],
        "attention_mask": text_inputs["attention_mask"],
        "label": example["label"],
        "id": example["id"],
    }


class MMSDDatasetModule(pl.LightningDataModule):
    def __init__(
        self,
        vision_ckpt_name: str,
        text_ckpt_name: str,
        dataset_version: str = "mmsd-v2",
        train_batch_size: int = 32,
        val_batch_size: int = 32,
        test_batch_size: int = 32,
        num_workers: int = 19,
    ) -> None:
        super().__init__()
        self.vision_ckpt_name = vision_ckpt_name
        self.text_ckpt_name = text_ckpt_name
        self.dataset_version = dataset_version
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
    def setup(self, stage: str) -> None:
        # The tokenizer's internal thread pool conflicts with DataLoader
        # worker processes; see
        # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        image_processor = AutoImageProcessor.from_pretrained(self.vision_ckpt_name)
        tokenizer = AutoTokenizer.from_pretrained(self.text_ckpt_name)
        # With no split argument, load_dataset returns a DatasetDict keyed by
        # split name ("train"/"validation"/"test"), not a single Dataset.
        self.dataset = cast(
            DatasetDict,
            load_dataset("<username>/MMSD2.0", name=self.dataset_version),
        )
        # Apply preprocessing lazily, on access, rather than materializing
        # the processed dataset up front.
        self.dataset.set_transform(
            partial(
                preprocess,
                image_processor=image_processor,
                tokenizer=tokenizer,
            )
        )
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset["train"],  # type: ignore
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset["validation"],  # type: ignore
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset["test"],  # type: ignore
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )
    def collate_fn(self, batch) -> MMSDModelInput:
        # The on-access transform yields Python lists / NumPy arrays, so each
        # field is converted to a tensor before stacking; `id` stays a plain
        # list of strings.
        return {
            "pixel_values": torch.stack(
                [torch.tensor(x["pixel_values"]) for x in batch]
            ),
            "input_ids": torch.stack([torch.tensor(x["input_ids"]) for x in batch]),
            "attention_mask": torch.stack(
                [torch.tensor(x["attention_mask"]) for x in batch]
            ),
            "label": torch.stack([torch.tensor(x["label"]) for x in batch]),
            "id": [x["id"] for x in batch],
        }
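

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): builds the
# DataModule and pulls one training batch. The checkpoint names below are
# hypothetical choices; any compatible Hugging Face image/text checkpoints
# should work, and the dataset repo id above must point to a real Hub repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    datamodule = MMSDDatasetModule(
        vision_ckpt_name="openai/clip-vit-base-patch32",  # assumed vision backbone
        text_ckpt_name="roberta-base",  # assumed text backbone
        num_workers=4,
    )
    datamodule.setup(stage="fit")
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["pixel_values"].shape, batch["input_ids"].shape, batch["label"].shape)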