-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_AncientSites.py
35 lines (28 loc) · 1000 Bytes
/
test_AncientSites.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import print_function, division
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from skimage import io
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import warnings
# Install a global "ignore" filter: every warning raised anywhere in the
# process after this line is suppressed, not only warnings from this file.
# NOTE(review): this also hides deprecation / PIL decoding warnings that may
# matter; consider narrowing to specific categories or modules.
warnings.filterwarnings("ignore")
class test_SitesDataset(Dataset):
    """Unlabelled map-style dataset over every entry in ``<root_dir>/test``.

    Entries are listed once at construction time and sorted by name, so index
    order is deterministic across runs (``os.listdir`` order is not). Each item
    is a dict ``{'dir': <image path>, 'image': <image>}``; the path is returned
    alongside the image, presumably so outputs can be matched back to files.

    Args:
        root_dir: Data directory; images are read from its ``test``
            sub-directory.
        transform: Optional callable applied to the loaded PIL image (e.g. a
            ``torchvision.transforms.Compose``). When ``None`` the PIL image
            itself is returned.
        mode: Optional PIL mode (e.g. ``'RGB'``) each image is converted to
            before ``transform`` is applied. ``None`` (the default) keeps the
            file's native mode.
    """

    def __init__(self, root_dir='/local_scratch/COSI149B/Project1/', transform=None, mode=None):
        self.root_dir = os.path.join(root_dir, 'test')
        # Full paths of every entry in root_dir, sorted by name. Despite the
        # attribute name these are file paths, not directories; the name is
        # kept so existing callers reading `img_dir` keep working.
        self.img_dir = [os.path.join(self.root_dir, name)
                        for name in sorted(os.listdir(self.root_dir))]
        self.transform = transform
        self.mode = mode

    def __len__(self):
        """Return the number of entries found under ``root_dir``."""
        return len(self.img_dir)

    def __getitem__(self, idx):
        """Load entry ``idx`` and return ``{'dir': path, 'image': image}``.

        Errors from ``PIL.Image.open`` (e.g. for a non-image file in the
        directory) are not caught and propagate to the caller.
        """
        img_path = self.img_dir[idx]
        # Image.open is lazy and keeps the file handle open until pixel data
        # is read; with many DataLoader workers that can exhaust the open-file
        # limit. Load eagerly inside the context manager so the handle is
        # closed deterministically; the loaded image remains usable afterwards.
        with Image.open(img_path) as image:
            image.load()
        if self.mode is not None and image.mode != self.mode:
            image = image.convert(self.mode)
        if self.transform is not None:
            image = self.transform(image)
        return {'dir': img_path, 'image': image}