Feat: add dataset #10

Merged: 5 commits, Jun 20, 2024
4 changes: 4 additions & 0 deletions clip/clip/constant.py
@@ -12,3 +12,7 @@
IMAGE_HEIGHT = 112
IMAGE_WIDTH = 112
IMAGE_CHANNEL = 3


# text
MAX_SEQ_LENGTH = 1024
Empty file added clip/clip/data/__init__.py
80 changes: 80 additions & 0 deletions clip/clip/data/dataset.py
@@ -0,0 +1,80 @@
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms
from pathlib import Path
import os
import json
from clip.constant import MAX_SEQ_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH
import tiktoken
import torch
from typing import List
import logging

logger = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO)


class CLIPDataset(Dataset):
    def __init__(
        self,
        data_dir: str,
        max_len: int = MAX_SEQ_LENGTH,
        img_height: int = IMAGE_HEIGHT,
        img_width: int = IMAGE_WIDTH,
    ):
        self.data_dir = data_dir
        self.max_len = max_len
        self.img_height = img_height
        self.img_width = img_width

        self.tokenizer = tiktoken.get_encoding("gpt2")
        # Each sample is described by a JSON sidecar with at least "key",
        # "caption", and "error_message"; entries with a non-empty
        # error_message are skipped.
        self.metadata = [self.read_json(js) for js in Path(data_dir).rglob("*.json")]
        self.img_data = [
            self.read_img(os.path.join(data_dir, f"{data['key']}.jpg"))
            for data in self.metadata
            if not data["error_message"]
        ]
        self.txt_data = [
            self.tokenize(data["caption"])
            for data in self.metadata
            if not data["error_message"]
        ]

    def __len__(self):
        return len(self.img_data)

    def __getitem__(self, idx):
        return self.txt_data[idx], self.img_data[idx]

    def read_json(self, js):
        with open(js, "r") as f:
            data = json.load(f)
        return data

    def read_img(self, img_path):
        raw_img = read_image(img_path)
        if raw_img.shape[0] != 3:
            # Broadcast single-channel (grayscale) images to three channels.
            raw_img = raw_img.expand(3, *raw_img.shape[1:])
            logger.info("Changing input image from grayscale to RGB.")
        logger.info(
            f"Resizing input image from {raw_img.shape[1:]} to {(self.img_height, self.img_width)}"
        )
        return transforms.Resize((self.img_height, self.img_width), antialias=True)(
            raw_img
        ).to(dtype=torch.float32)

    def tokenize(self, text: str) -> List[int]:
        tokens = self.tokenizer.encode(text)
        if len(tokens) >= self.max_len:
            # Truncate so the sequence still ends with the EOS token.
            tokens = tokens[: self.max_len - 1]
            tokens += [self.tokenizer._special_tokens["<|endoftext|>"]]
        else:
            # Pad with zeros after the EOS token so every sequence has a
            # fixed length and can be batched.
            tokens += [self.tokenizer._special_tokens["<|endoftext|>"]]
            tokens += [0] * (self.max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.int32)
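
For context, a minimal usage sketch. It assumes a hypothetical ./data directory holding <key>.jpg images with matching <key>.json sidecars (key, caption, error_message), which is the layout CLIPDataset expects:

from torch.utils.data import DataLoader

from clip.data.dataset import CLIPDataset

# Hypothetical data directory: <key>.jpg images plus <key>.json sidecars.
dataset = CLIPDataset(data_dir="./data")

# drop_last keeps every batch at the full batch size, which CLIPLoss assumes.
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

for tokens, images in loader:
    # tokens: (32, 1024) int32 token ids; images: (32, 3, 112, 112) float32
    break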
4 changes: 2 additions & 2 deletions clip/clip/image/vit.py
@@ -1,10 +1,10 @@
import torch
from dataclasses import dataclass, field
-from ..constant import DEVICE
+from clip.constant import DEVICE
import torch.nn as nn
from torch.nn import functional as F
from einops.layers.torch import Rearrange
-from ..constant import IMAGE_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH
+from clip.constant import IMAGE_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH


@dataclass
5 changes: 3 additions & 2 deletions clip/clip/languange/gpt.py
@@ -1,8 +1,9 @@
import torch
from dataclasses import dataclass, field
-from ..constant import DEVICE
+from clip.constant import DEVICE
import torch.nn as nn
from torch.nn import functional as F
from clip.constant import MAX_SEQ_LENGTH


@dataclass
@@ -11,7 +12,7 @@ class GPTConfig:
    vocab_size: int = field(
        default=65536, metadata={"help": "define size of vocabulary"}
    )  # 2**16
-    seq_len: int = field(default=1024, metadata={"help": "sequence length"})
+    seq_len: int = field(default=MAX_SEQ_LENGTH, metadata={"help": "sequence length"})
    n_layer: int = field(default=12, metadata={"help": "number of layers"})
    n_head: int = field(default=8, metadata={"help": "number of heads"})
    n_embd: int = field(default=768, metadata={"help": "embedding dimension"})
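
Pointing seq_len at MAX_SEQ_LENGTH keeps the text encoder's default context length in sync with what the dataset's tokenizer produces; a small sketch of that invariant, assuming the defaults shown above:

from clip.constant import MAX_SEQ_LENGTH
from clip.languange.gpt import GPTConfig  # module path as spelled in the repo

config = GPTConfig()
# CLIPDataset pads/truncates every caption to MAX_SEQ_LENGTH tokens,
# so tokenized captions line up with the model's default sequence length.
assert config.seq_len == MAX_SEQ_LENGTH == 1024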
7 changes: 4 additions & 3 deletions clip/clip/loss.py
@@ -1,16 +1,17 @@
import torch
import torch.nn as nn
from clip.constant import DEVICE


class CLIPLoss(nn.Module):
-    def __init__(self, batch_size: int):
+    def __init__(self, batch_size: int, device=DEVICE):
        super().__init__()
        self.batch_size = batch_size
-        self.label = torch.arange(0, self.batch_size)
+        self.label = torch.arange(0, self.batch_size, dtype=torch.long, device=device)
        self.img_loss = nn.CrossEntropyLoss()
        self.txt_loss = nn.CrossEntropyLoss()

-    def forward(self, img_log, txt_log):
+    def forward(self, txt_log, img_log):
        # symmetric cross-entropy: the i-th text should match the i-th image
        loss_images = self.img_loss(img_log, self.label)
        loss_text = self.txt_loss(txt_log, self.label)
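
A minimal sketch of how the updated loss might be exercised, assuming (batch, batch) similarity logits where row i and column i correspond to the i-th matched text-image pair; the tensors below are random placeholders:

import torch

from clip.loss import CLIPLoss

batch_size = 4
loss_fn = CLIPLoss(batch_size=batch_size, device="cpu")

# Placeholder logits; in training these would come from the text and image
# embeddings, e.g. txt_log = txt_emb @ img_emb.t() and img_log = txt_log.t().
txt_log = torch.randn(batch_size, batch_size)
img_log = txt_log.t()

loss = loss_fn(txt_log, img_log)  # argument order follows the new signature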