Skip to content

Commit

Permalink
Merge pull request #78 from KevinMusgrave/dev
Browse files: browse the repository at this point in the history
v0.0.77
  • Loading branch information
KevinMusgrave committed Jul 23, 2022
2 parents 9d54fb0 + a73fc6d commit 8f0b28c
Show file tree
Hide file tree
Showing 15 changed files with 3,504 additions and 98 deletions.
2 changes: 1 addition & 1 deletion build_script.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
./format_code.sh
RUN_DATASET_TESTS=true python -m unittest discover && \
RUN_DATASET_TESTS=true RUN_DOMAINNET126_DATASET_TESTS=true python -m unittest discover && \
rm -rfv build/ && \
rm -rfv dist/ && \
rm -rfv src/pytorch_adapt.egg-info/ && \
Expand Down
3,320 changes: 3,264 additions & 56 deletions examples/other/DatasetViz.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
long_description = fh.read()


extras_require_ignite = ["pytorch-ignite >= 0.4.9"]
extras_require_ignite = ["pytorch-ignite == 0.4.9"]
extras_require_lightning = ["pytorch-lightning"]
extras_require_record_keeper = ["record-keeper >= 0.9.31"]
extras_require_timm = ["timm"]
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_adapt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.76"
__version__ = "0.0.77"
2 changes: 1 addition & 1 deletion src/pytorch_adapt/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .concat_dataset import ConcatDataset
from .dataloader_creator import DataloaderCreator
from .domainnet import DomainNet, DomainNet126, DomainNet126Full
from .getters import get_mnist_mnistm, get_office31, get_officehome
from .getters import get_domainnet126, get_mnist_mnistm, get_office31, get_officehome
from .mnistm import MNISTM
from .office31 import Office31, Office31Full
from .officehome import OfficeHome, OfficeHomeFull
Expand Down
29 changes: 17 additions & 12 deletions src/pytorch_adapt/datasets/domainnet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from collections import OrderedDict

from .base_dataset import BaseDataset
from .utils import check_img_paths, check_length
from .base_dataset import BaseDataset, BaseDownloadableDataset
from .utils import check_img_paths, check_length, check_train


class DomainNet(BaseDataset):
Expand Down Expand Up @@ -84,38 +84,43 @@ def __init__(self, root: str, domain: str, transform, **kwargs):
self.transform = transform


class DomainNet126(BaseDataset):
class DomainNet126(BaseDownloadableDataset):
"""
A custom train/test split of DomainNet126Full.
"""

def __init__(self, root: str, domain: str, train: bool, transform, **kwargs):
url = "https://cornell.box.com/shared/static/5uu0v3rs9heusbiht2nn1gbn4yfspas6"
filename = "domainnet126.tar.gz"
md5 = "50f29fa0152d715c036c813ad67502d6"

def __init__(self, root: str, domain: str, train: bool, transform=None, **kwargs):
"""
Arguments:
root: The dataset must be located at ```<root>/domainnet```
domain: One of the 4 domains
train: Whether or not to use the training set.
transform: The image transform applied to each sample.
"""
super().__init__(domain=domain, **kwargs)
if not isinstance(train, bool):
raise TypeError("train should be True or False")
name = "train" if train else "test"
labels_file = os.path.join(root, "domainnet", f"{domain}126_{name}.txt")
self.train = check_train(train)
super().__init__(root=root, domain=domain, **kwargs)
self.transform = transform

def set_paths_and_labels(self, root):
name = "train" if self.train else "test"
labels_file = os.path.join(root, "domainnet", f"{self.domain}126_{name}.txt")
img_dir = os.path.join(root, "domainnet")

with open(labels_file) as f:
content = [line.rstrip().split(" ") for line in f]
self.img_paths = [os.path.join(img_dir, x[0]) for x in content]
check_img_paths(img_dir, self.img_paths, domain)
check_img_paths(img_dir, self.img_paths, self.domain)
check_length(
self,
{
"clipart": {"train": 14962, "test": 3741}[name],
"painting": {"train": 25201, "test": 6301}[name],
"real": {"train": 56286, "test": 14072}[name],
"sketch": {"train": 19665, "test": 4917}[name],
}[domain],
}[self.domain],
)
self.labels = [int(x[1]) for x in content]
self.transform = transform
43 changes: 22 additions & 21 deletions src/pytorch_adapt/datasets/getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
from ..utils.transforms import GrayscaleToRGB
from .combined_source_and_target import CombinedSourceAndTargetDataset
from .concat_dataset import ConcatDataset
from .domainnet import DomainNet126
from .mnistm import MNISTM
from .office31 import Office31
from .officehome import OfficeHome
from .source_dataset import SourceDataset
from .target_dataset import TargetDataset


def get_multiple(dataset_getter, domains, *args):
return ConcatDataset([dataset_getter(d, *args) for d in domains])
def get_multiple(dataset_getter, domains, **kwargs):
return ConcatDataset([dataset_getter(domain=d, **kwargs) for d in domains])


def get_datasets(
Expand All @@ -31,11 +32,11 @@ def getter(domains, train, is_training):
return get_multiple(
dataset_getter,
domains,
train,
is_training,
folder,
download,
transform_getter,
train=train,
is_training=is_training,
root=folder,
download=download,
transform_getter=transform_getter,
)

if not src_domains and not target_domains:
Expand Down Expand Up @@ -93,15 +94,15 @@ def get_mnist_transform(domain, *_):
)


def _get_mnist_mnistm(domain, train, is_training, folder, download, transform_getter):
def _get_mnist_mnistm(is_training, transform_getter, **kwargs):
transform_getter = c_f.default(transform_getter, get_mnist_transform)
transform = transform_getter(domain, train, is_training)
domain = kwargs["domain"]
kwargs["transform"] = transform_getter(domain, kwargs["train"], is_training)
kwargs.pop("domain")
if domain == "mnist":
return datasets.MNIST(
folder, train=train, transform=transform, download=download
)
return datasets.MNIST(**kwargs)
elif domain == "mnistm":
return MNISTM(folder, train, transform, download=download)
return MNISTM(**kwargs)


def get_mnist_mnistm(*args, **kwargs):
Expand All @@ -126,16 +127,12 @@ def get_resnet_transform(domain, train, is_training):


def standard_dataset(cls):
def fn(domain, train, is_training, folder, download, transform_getter):
def fn(is_training, transform_getter, **kwargs):
transform_getter = c_f.default(transform_getter, get_resnet_transform)
transform = transform_getter(domain, train, is_training)
return cls(
root=folder,
domain=domain,
train=train,
transform=transform,
download=download,
kwargs["transform"] = transform_getter(
kwargs["domain"], kwargs["train"], is_training
)
return cls(**kwargs)

return fn

Expand All @@ -146,3 +143,7 @@ def get_office31(*args, **kwargs):

def get_officehome(*args, **kwargs):
    """
    Builds the OfficeHome datasets.

    All positional and keyword arguments are forwarded unchanged to
    ```get_datasets```, with ```standard_dataset(OfficeHome)``` supplied
    as the dataset getter.
    """
    dataset_getter = standard_dataset(OfficeHome)
    return get_datasets(dataset_getter, *args, **kwargs)


def get_domainnet126(*args, **kwargs):
    """
    Builds the DomainNet126 datasets.

    All positional and keyword arguments are forwarded unchanged to
    ```get_datasets```, with ```standard_dataset(DomainNet126)``` supplied
    as the dataset getter.
    """
    dataset_getter = standard_dataset(DomainNet126)
    return get_datasets(dataset_getter, *args, **kwargs)
23 changes: 23 additions & 0 deletions src/pytorch_adapt/models/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,26 @@ def officehomeC(
return download_weights(
model, url, pretrained, progress=progress, file_name=file_name, **kwargs
)


def domainnet126G(*args, **kwargs):
    """
    Thin wrapper around ```resnet50```; every argument is passed through.

    Returns:
        A ResNet50 model trained on ImageNet, if ```pretrained == True```.
    """
    model = resnet50(*args, **kwargs)
    return model


def domainnet126C(
    domain=None,
    num_classes=126,
    in_size=2048,
    h=256,
    pretrained=False,
    progress=True,
    **kwargs,
):
    """
    Creates a classifier head for DomainNet126.

    Arguments:
        domain: Accepted but not used by this function.
        num_classes: Passed to ```Classifier``` as ```num_classes```.
        in_size: Passed to ```Classifier``` as ```in_size```.
        h: Passed to ```Classifier``` as ```h```.
        pretrained: Must be falsy. Pretrained weights are not yet available.
        progress: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        A ```Classifier``` built from ```num_classes```, ```in_size``` and ```h```.

    Raises:
        ValueError: If ```pretrained``` is truthy.
    """
    if not pretrained:
        return Classifier(num_classes=num_classes, in_size=in_size, h=h)

    raise ValueError("pretrained=True not yet supported")
1 change: 1 addition & 0 deletions src/pytorch_adapt/validators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .knn_validator import KNNValidator
from .mmd_validator import MMDValidator
from .multiple_validators import MultipleValidators
from .nearest_source_validator import NearestSourceL2Validator, NearestSourceValidator
from .per_class_validator import PerClassValidator
from .score_history import ScoreHistories, ScoreHistory
from .snd_validator import SNDValidator
Expand Down
78 changes: 78 additions & 0 deletions src/pytorch_adapt/validators/nearest_source_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
from pytorch_metric_learning.distances import (
BatchedDistance,
CosineSimilarity,
LpDistance,
)
from pytorch_metric_learning.utils.inference import CustomKNN

from .base_validator import BaseValidator


def acc(preds, labels):
    """
    Computes per-sample correctness of the argmax prediction.

    Arguments:
        preds: Tensor of per-class scores. The predicted class is the
            argmax along dim 1, so ```preds.shape[1]``` is the number of classes.
        labels: Tensor of integer class labels, compared elementwise with
            the argmax of ```preds``` (presumably one label per row of ```preds```).

    Returns:
        A float tensor that is 1 where the prediction matches the label
        and 0 elsewhere.

    Raises:
        ValueError: If ```labels``` is empty, or if the largest label
            is not equal to ```preds.shape[1] - 1```.
    """
    if labels.numel() == 0:
        raise ValueError("labels must not be empty")

    # torch.max reduces in a single call, instead of Python's max()
    # iterating over the tensor one element at a time.
    max_label = int(torch.max(labels))
    expected_max = preds.shape[1] - 1
    if max_label != expected_max:
        raise ValueError(
            f"Max label {max_label} should be equal to preds.shape[1] - 1 ({expected_max})"
        )
    pred_classes = torch.argmax(preds, dim=1)
    return (pred_classes == labels).float()


class NearestSourceValidator(BaseValidator):
    """
    Scores a model by looking up, for every target sample, its nearest
    source validation sample (by cosine similarity of ```layer```), and
    averaging the correctness of those nearest source samples.
    """

    def __init__(self, layer="preds", threshold=0, weighted=False, **kwargs):
        """
        Arguments:
            layer: The key of the features used for the nearest-neighbor search.
            threshold: Target samples whose nearest-source similarity is at
                or below this value contribute 0 to the score.
            weighted: If True, each contribution is scaled by its similarity,
                rescaled so that ```threshold``` maps to 0 and the largest
                similarity maps to 1. If False, contributions are simply
                zeroed at or below ```threshold```.
            **kwargs: Passed to the parent class.
        """
        super().__init__(**kwargs)
        self.layer = layer
        self.threshold = threshold
        self.weighted = weighted
        self.knn_fn = CustomKNN(CosineSimilarity())

    def compute_score(self, src_val, target_train):
        """
        Returns:
            The mean (as a Python float) of the per-target-sample
            nearest-source accuracies, after thresholding or weighting.
        """
        nearest_src_acc, sims = self.get_nearest_src_acc(src_val, target_train)

        if self.weighted:
            scale = torch.max(sims) - self.threshold
            if scale > 0:
                # Rescale to [0, 1]; anything at or below the threshold gets weight 0.
                weights = torch.clamp((sims - self.threshold) / scale, min=0)
            else:
                # No similarity exceeds the threshold. Dividing by a zero or
                # negative scale would give NaN/inf or flip below-threshold
                # similarities into positive weights, so every sample gets
                # weight 0, matching the unweighted branch.
                weights = torch.zeros_like(sims)
            nearest_src_acc = nearest_src_acc * weights
        else:
            nearest_src_acc[sims <= self.threshold] = 0

        return torch.mean(nearest_src_acc).item()

    def get_nearest_src_acc(self, src_val, target_train):
        """
        Finds the single nearest source sample for each target sample.

        Arguments:
            src_val: Must contain ```"preds"```, ```"labels"```, and ```self.layer```.
            target_train: Must contain ```self.layer```.

        Returns:
            A tuple ```(nearest_src_acc, sims)```: the correctness (0 or 1)
            of each target sample's nearest source sample, and the value
            returned by ```self.knn_fn``` for that nearest neighbor.
        """
        # Accuracy is always computed from the source "preds",
        # regardless of which layer is used for the neighbor search.
        src_acc = acc(src_val["preds"], src_val["labels"])

        sims, idx = self.knn_fn(
            target_train[self.layer],
            k=1,
            reference=src_val[self.layer],
            embeddings_come_from_same_source=False,
        )
        # k=1, so drop the neighbor dimension.
        sims, idx = sims.squeeze(1), idx.squeeze(1)
        nearest_src_acc = src_acc[idx]
        return nearest_src_acc, sims


class NearestSourceL2Validator(NearestSourceValidator):
    """
    Variant of ```NearestSourceValidator``` that uses ```LpDistance```
    (without normalizing embeddings) instead of cosine similarity.
    Each target sample's nearest-source accuracy is weighted by
    ```1 - dist / max_dist```, where ```max_dist``` is the largest pairwise
    distance among all source and target features.
    """

    def __init__(self, layer="preds", **kwargs):
        """
        Arguments:
            layer: The key of the features used for the nearest-neighbor search.
            **kwargs: Passed to the parent class.
        """
        # threshold and weighted are fixed here; this class overrides
        # compute_score and does not read either of them.
        super().__init__(layer=layer, threshold=float("inf"), weighted=True, **kwargs)
        dist_fn = LpDistance(normalize_embeddings=False)
        self.knn_fn = CustomKNN(dist_fn)
        self.all_dist_fn = BatchedDistance(dist_fn, batch_size=1024)

    def compute_score(self, src_val, target_train):
        """
        Returns:
            The mean (as a Python float) of the distance-weighted
            nearest-source accuracies.
        """
        # One-element list so the callback can update the running maximum.
        max_dist = [0]

        def iter_fn(mat, *_):
            max_dist[0] = max(max_dist[0], torch.max(mat))

        all_feats = torch.cat([src_val[self.layer], target_train[self.layer]], dim=0)
        self.all_dist_fn.iter_fn = iter_fn
        self.all_dist_fn(all_feats)

        nearest_src_acc, dists = self.get_nearest_src_acc(src_val, target_train)
        # If every pairwise distance is 0, the nearest distances are also 0.
        # Skipping the division avoids 0 / 0 = NaN and leaves each weight at 1.
        if max_dist[0] > 0:
            dists /= max_dist[0]
        nearest_src_acc *= 1 - dists
        return torch.mean(nearest_src_acc).item()


#
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DATASET_FOLDER = "zzz_pytorch_adapt_dataset_test_folder"
RUN_DATASET_TESTS = os.environ.get("RUN_DATASET_TESTS", False)
RUN_DOMAINNET_DATASET_TESTS = os.environ.get("RUN_DOMAINNET_DATASET_TESTS", False)
RUN_DOMAINNET126_DATASET_TESTS = os.environ.get("RUN_DOMAINNET126_DATASET_TESTS", False)

TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]
TEST_DEVICE = torch.device(device_from_environ)
Expand Down
11 changes: 8 additions & 3 deletions tests/datasets/test_domainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@

from pytorch_adapt.datasets import DomainNet, DomainNet126, DomainNet126Full

from .. import DATASET_FOLDER, RUN_DOMAINNET_DATASET_TESTS
from .. import (
DATASET_FOLDER,
RUN_DOMAINNET126_DATASET_TESTS,
RUN_DOMAINNET_DATASET_TESTS,
)
from .utils import (
check_full,
check_train_test_disjoint,
check_train_test_matches_full,
loop_through_dataset,
skip_reason_domainnet,
skip_reason_domainnet126,
)


Expand Down Expand Up @@ -39,7 +44,7 @@ def test_domainnet(self):
self.assertTrue(len(dataset) == length)
loop_through_dataset(dataset)

@unittest.skipIf(not RUN_DOMAINNET_DATASET_TESTS, skip_reason_domainnet)
@unittest.skipIf(not RUN_DOMAINNET126_DATASET_TESTS, skip_reason_domainnet126)
def test_domainnet126(self):
check_train_test_matches_full(
self,
Expand All @@ -50,7 +55,7 @@ def test_domainnet126(self):
DATASET_FOLDER,
)

@unittest.skipIf(not RUN_DOMAINNET_DATASET_TESTS, skip_reason_domainnet)
@unittest.skipIf(not RUN_DOMAINNET126_DATASET_TESTS, skip_reason_domainnet126)
def test_domainnet126_full(self):
check_full(
self,
Expand Down
28 changes: 26 additions & 2 deletions tests/datasets/test_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pytorch_adapt.datasets import (
MNISTM,
DomainNet126,
Office31,
OfficeHome,
SourceDataset,
Expand All @@ -12,9 +13,10 @@
get_office31,
get_officehome,
)
from pytorch_adapt.datasets.getters import get_domainnet126

from .. import DATASET_FOLDER, RUN_DATASET_TESTS
from .utils import skip_reason
from .. import DATASET_FOLDER, RUN_DATASET_TESTS, RUN_DOMAINNET126_DATASET_TESTS
from .utils import skip_reason, skip_reason_domainnet126


class TestGetters(unittest.TestCase):
Expand Down Expand Up @@ -138,3 +140,25 @@ def test_office31(self):
"target_val": 159,
},
)

@unittest.skipIf(not RUN_DOMAINNET126_DATASET_TESTS, skip_reason_domainnet126)
def test_domainnet126(self):
    # "real" is the source domain and "sketch" is the target domain.
    expected_sizes = {
        "src_train": 56286,
        "src_val": 14072,
        "target_train": 19665,
        "target_val": 4917,
    }
    src_domains, target_domains = ["real"], ["sketch"]
    datasets = get_domainnet126(
        src_domains, target_domains, folder=DATASET_FOLDER, download=True
    )
    self.helper(datasets, DomainNet126, DomainNet126, expected_sizes)

def test_incorrect_train_arg(self):
    # Every dataset class must reject a non-boolean ```train``` argument.
    # subTest reports which class failed and keeps checking the rest.
    for dataset in [MNISTM, Office31, OfficeHome, DomainNet126]:
        with self.subTest(dataset=dataset.__name__):
            with self.assertRaises(TypeError):
                dataset(root=DATASET_FOLDER, domain=None, train="something")
1 change: 1 addition & 0 deletions tests/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

skip_reason = "RUN_DATASET_TESTS is False"
skip_reason_domainnet = "RUN_DOMAINNET_DATASET_TESTS is False"
skip_reason_domainnet126 = "RUN_DOMAINNET126_DATASET_TESTS is False"


def simple_transform():
Expand Down
Loading

0 comments on commit 8f0b28c

Please sign in to comment.