Merge pull request #86 from KevinMusgrave/dev
v0.0.79
KevinMusgrave authored Aug 16, 2022
2 parents a0fbdd7 + efc25db commit 6fbae70
Showing 40 changed files with 835 additions and 162 deletions.
4 changes: 0 additions & 4 deletions format_code.bat

This file was deleted.

4 changes: 2 additions & 2 deletions format_code.sh
@@ -1,4 +1,4 @@
-black . --exclude examples
-isort . --profile black --skip-glob examples
+black src tests
+isort src tests --profile black
 nbqa black examples
 nbqa isort examples --profile black
2 changes: 0 additions & 2 deletions run_linter.bat

This file was deleted.

2 changes: 1 addition & 1 deletion run_linter.sh
@@ -1,2 +1,2 @@
-flake8 . --count --show-source --statistics
+flake8 src tests --count --show-source --statistics
 nbqa flake8 examples --count --show-source --statistics
8 changes: 4 additions & 4 deletions setup.py
@@ -9,10 +9,10 @@
 long_description = fh.read()


-extras_require_detection = ["albumentations"]
+extras_require_detection = ["albumentations >= 1.2.1"]
 extras_require_ignite = ["pytorch-ignite == 0.4.9"]
 extras_require_lightning = ["pytorch-lightning"]
-extras_require_record_keeper = ["record-keeper >= 0.9.31"]
+extras_require_record_keeper = ["record-keeper >= 0.9.32"]
 extras_require_timm = ["timm"]
 extras_require_docs = [
     "mkdocs-material",
@@ -44,8 +44,8 @@
         "numpy",
         "torch",
         "torchvision",
-        "torchmetrics",
-        "pytorch-metric-learning >= 1.3.1.dev0",
+        "torchmetrics >= 0.9.3",
+        "pytorch-metric-learning >= 1.5.2",
     ],
     extras_require={
         "detection": extras_require_detection,
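These optional dependency groups map to pip's extras syntax, e.g. `pip install pytorch-adapt[detection]` (assuming the published package name matches this repository).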
2 changes: 1 addition & 1 deletion src/pytorch_adapt/__init__.py
@@ -1 +1 @@
-__version__ = "0.0.78"
+__version__ = "0.0.79"
2 changes: 1 addition & 1 deletion src/pytorch_adapt/adapters/__init__.py
@@ -2,7 +2,7 @@
 from .adda import ADDA
 from .aligner import RTN, Aligner
 from .base_adapter import BaseAdapter
-from .classifier import Classifier, Finetuner
+from .classifier import Classifier, Finetuner, MultiLabelClassifier
 from .dann import CDANNE, DANN, DANNE, GVB, GVBE
 from .gan import CDAN, CDANE, GAN, GANE, VADA, DomainConfusion
 from .mcd import MCD
6 changes: 4 additions & 2 deletions src/pytorch_adapt/adapters/adabn.py
@@ -14,8 +14,6 @@ class AdaBN(BaseAdapter):
     |models|```["G", "C"]```|
     """

-    hook_cls = AdaBNHook
-
     def __init__(self, *args, inference_fn=None, **kwargs):
         """
         Arguments:
@@ -27,6 +25,10 @@ def __init__(self, *args, inference_fn=None, **kwargs):
     def init_hook(self, hook_kwargs):
         self.hook = self.hook_cls(**hook_kwargs)

+    @property
+    def hook_cls(self):
+        return AdaBNHook
+
     def get_key_enforcer(self) -> KeyEnforcer:
         return KeyEnforcer(models=["G", "C"], optimizers=[])
6 changes: 4 additions & 2 deletions src/pytorch_adapt/adapters/adda.py
@@ -20,8 +20,6 @@ class ADDA(BaseAdapter):
     The target model ("T") is created during initialization by deep-copying the G model.
     """

-    hook_cls = ADDAHook
-
     def __init__(self, *args, inference_fn=None, **kwargs):
         """
         Arguments:
@@ -55,3 +53,7 @@ def init_hook(self, hook_kwargs):
     def init_containers_and_check_keys(self, containers):
         containers["models"]["T"] = copy.deepcopy(containers["models"]["G"])
         super().init_containers_and_check_keys(containers)
+
+    @property
+    def hook_cls(self):
+        return ADDAHook
12 changes: 8 additions & 4 deletions src/pytorch_adapt/adapters/aligner.py
@@ -16,12 +16,14 @@ class Aligner(BaseGCAdapter):
     |optimizers|```["G", "C"]```|
     """

-    hook_cls = AlignerPlusCHook
-
     def init_hook(self, hook_kwargs):
         opts = with_opt(list(self.optimizers.keys()))
         self.hook = self.hook_cls(opts, **hook_kwargs)

+    @property
+    def hook_cls(self):
+        return AlignerPlusCHook
+

 class RTN(Aligner):
     """
@@ -34,8 +36,6 @@ class RTN(Aligner):
     |misc|```["feature_combiner"]```|
     """

-    hook_cls = RTNHook
-
     def __init__(self, *args, inference_fn=None, **kwargs):
         """
         Arguments:
@@ -44,6 +44,10 @@
         inference_fn = c_f.default(inference_fn, rtn_fn)
         super().__init__(*args, inference_fn=inference_fn, **kwargs)

+    @property
+    def hook_cls(self):
+        return RTNHook
+
     def get_key_enforcer(self) -> KeyEnforcer:
         ke = super().get_key_enforcer()
         ke.requirements["models"].append("residual_model")
5 changes: 5 additions & 0 deletions src/pytorch_adapt/adapters/base_adapter.py
@@ -150,6 +150,11 @@ def init_hook(self):
         """
         pass

+    @property
+    @abstractmethod
+    def hook_cls(self):
+        pass
+
     def init_containers_and_check_keys(self, containers):
         """
         Called in ```__init__``` before
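The recurring change in this commit is the one above: each adapter's `hook_cls` class attribute becomes an override of this abstract read-only property. A minimal sketch of the pattern in isolation (the `ExampleHook`/`ExampleAdapter` names are hypothetical, not the library's classes):

```python
from abc import ABC, abstractmethod


class ExampleHook:  # hypothetical stand-in for a hook class
    pass


class BaseAdapter(ABC):
    # Subclasses must expose their hook class through a property;
    # a plain class attribute cannot be enforced by ABC machinery.
    @property
    @abstractmethod
    def hook_cls(self):
        pass


class ExampleAdapter(BaseAdapter):
    @property
    def hook_cls(self):
        return ExampleHook


adapter = ExampleAdapter()
print(adapter.hook_cls)  # <class '__main__.ExampleHook'>
```

One effect of the property form is that instantiating a subclass that forgets to define `hook_cls` raises a `TypeError` immediately, instead of failing later with an `AttributeError`.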
18 changes: 14 additions & 4 deletions src/pytorch_adapt/adapters/classifier.py
@@ -1,5 +1,5 @@
 from ..containers import KeyEnforcer, MultipleContainers, Optimizers
-from ..hooks import ClassifierHook, FinetunerHook
+from ..hooks import ClassifierHook, FinetunerHook, MultiLabelClassifierHook
 from .base_adapter import BaseGCAdapter
 from .utils import default_optimizer_tuple, with_opt

@@ -14,12 +14,14 @@ class Classifier(BaseGCAdapter):
     |optimizers|```["G", "C"]```|
     """

-    hook_cls = ClassifierHook
-
     def init_hook(self, hook_kwargs):
         opts = with_opt(list(self.optimizers.keys()))
         self.hook = self.hook_cls(opts=opts, **hook_kwargs)

+    @property
+    def hook_cls(self):
+        return ClassifierHook
+

 class Finetuner(Classifier):
     """
@@ -31,7 +33,9 @@ class Finetuner(Classifier):
     |optimizers|```["C"]```|
     """

-    hook_cls = FinetunerHook
+    @property
+    def hook_cls(self):
+        return FinetunerHook

     def get_default_containers(self) -> MultipleContainers:
         optimizers = Optimizers(default_optimizer_tuple(), keys=["C"])
@@ -41,3 +45,9 @@ def get_key_enforcer(self) -> KeyEnforcer:
         ke = super().get_key_enforcer()
         ke.requirements["optimizers"].remove("G")
         return ke
+
+
+class MultiLabelClassifier(Classifier):
+    @property
+    def hook_cls(self):
+        return MultiLabelClassifierHook
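The new `MultiLabelClassifier` differs from `Classifier` only in the hook it returns. A minimal usage sketch, assuming the container-based constructor shared by the other adapters (the G/C modules below are hypothetical placeholders, not real models):

```python
import torch
from pytorch_adapt.adapters import MultiLabelClassifier
from pytorch_adapt.containers import Models

# Hypothetical trunk (G) and multi-label head (C). The 20 outputs
# mirror the 20 classes of voc_multilabel (see datasets/utils.py below).
G = torch.nn.Linear(128, 64)
C = torch.nn.Linear(64, 20)

adapter = MultiLabelClassifier(models=Models({"G": G, "C": C}))
```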
22 changes: 16 additions & 6 deletions src/pytorch_adapt/adapters/dann.py
@@ -15,19 +15,25 @@ class DANN(BaseGCDAdapter):
     |optimizers|```["G", "C", "D"]```|
     """

-    hook_cls = DANNHook
-
     def init_hook(self, hook_kwargs):
         opts = with_opt(list(self.optimizers.keys()))
         self.hook = self.hook_cls(opts=opts, **hook_kwargs)

+    @property
+    def hook_cls(self):
+        return DANNHook
+

 class DANNE(DANN):
-    hook_cls = DANNEHook
+    @property
+    def hook_cls(self):
+        return DANNEHook


 class CDANNE(DANN, CDAN):
-    hook_cls = CDANNEHook
+    @property
+    def hook_cls(self):
+        return CDANNEHook


 class GVB(DANN):
@@ -44,7 +50,9 @@ class GVB(DANN):
     with each bridge being a re-initialized copy of each model.
     """

-    hook_cls = GVBHook
+    @property
+    def hook_cls(self):
+        return GVBHook

     def init_containers_and_check_keys(self, containers):
         models = containers["models"]
@@ -55,4 +63,6 @@


 class GVBE(GVB):
-    hook_cls = GVBEHook
+    @property
+    def hook_cls(self):
+        return GVBEHook
26 changes: 19 additions & 7 deletions src/pytorch_adapt/adapters/gan.py
@@ -23,16 +23,20 @@ class GAN(BaseGCDAdapter):
     |optimizers|```["G", "C", "D"]```|
     """

-    hook_cls = GANHook
-
     def init_hook(self, hook_kwargs):
         g_opts = with_opt(["G", "C"])
         d_opts = with_opt(["D"])
         self.hook = self.hook_cls(d_opts=d_opts, g_opts=g_opts, **hook_kwargs)

+    @property
+    def hook_cls(self):
+        return GANHook
+

 class GANE(GAN):
-    hook_cls = GANEHook
+    @property
+    def hook_cls(self):
+        return GANEHook


 class CDAN(GAN):
@@ -46,7 +50,9 @@ class CDAN(GAN):
     |misc|```["feature_combiner"]```|
     """

-    hook_cls = CDANHook
+    @property
+    def hook_cls(self):
+        return CDANHook

     def get_key_enforcer(self) -> KeyEnforcer:
         ke = super().get_key_enforcer()
@@ -55,7 +61,9 @@


 class CDANE(CDAN):
-    hook_cls = CDANEHook
+    @property
+    def hook_cls(self):
+        return CDANEHook


 class DomainConfusion(GAN):
@@ -68,7 +76,9 @@ class DomainConfusion(GAN):
     |optimizers|```["G", "C", "D"]```|
     """

-    hook_cls = DomainConfusionHook
+    @property
+    def hook_cls(self):
+        return DomainConfusionHook


 class VADA(GAN):
@@ -86,7 +96,9 @@ class VADA(GAN):
     automatically during initialization.
     """

-    hook_cls = VADAHook
+    @property
+    def hook_cls(self):
+        return VADAHook

     def init_containers_and_check_keys(self, containers):
         models = containers["models"]
6 changes: 4 additions & 2 deletions src/pytorch_adapt/adapters/mcd.py
@@ -19,8 +19,6 @@ class MCD(BaseGCAdapter):
     classifiers is 2, so C should output ```[logits1, logits2]```.
     """

-    hook_cls = MCDHook
-
     def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, mcd_fn)
         super().__init__(*args, inference_fn=inference_fn, **kwargs)
@@ -29,3 +27,7 @@ def init_hook(self, hook_kwargs):
         self.hook = self.hook_cls(
             g_opts=with_opt(["G"]), c_opts=with_opt(["C"]), **hook_kwargs
         )
+
+    @property
+    def hook_cls(self):
+        return MCDHook
6 changes: 4 additions & 2 deletions src/pytorch_adapt/adapters/symnets.py
@@ -17,8 +17,6 @@ class SymNets(BaseGCAdapter):
     The C model must output a list of logits: ```[logits1, logits2]```.
     """

-    hook_cls = SymNetsHook
-
     def __init__(self, *args, inference_fn=None, **kwargs):
         inference_fn = c_f.default(inference_fn, symnets_fn)
         super().__init__(*args, inference_fn=inference_fn, **kwargs)
@@ -27,3 +25,7 @@ def init_hook(self, hook_kwargs):
         self.hook = self.hook_cls(
             g_opts=with_opt(["G"]), c_opts=with_opt(["C"]), **hook_kwargs
         )
+
+    @property
+    def hook_cls(self):
+        return SymNetsHook
1 change: 1 addition & 0 deletions src/pytorch_adapt/datasets/base_dataset.py
@@ -61,5 +61,6 @@ def download_dataset(self, root):
         download_url(self.url, root, filename=self.filename, md5=self.md5)
         filepath = os.path.join(root, self.filename)
         decompressor = tarfile.open if tarfile.is_tarfile(filepath) else zipfile.ZipFile
+        c_f.LOGGER.info("Extracting")
         with decompressor(filepath, "r") as f:
             f.extractall(path=root, members=c_f.extract_progress(f))
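The `decompressor` line above picks an opener by sniffing the downloaded file. A standalone sketch of the same pattern (the `open_archive` name is illustrative, not from the library):

```python
import tarfile
import zipfile


def open_archive(filepath):
    # tarfile.is_tarfile also recognizes compressed tars (.tar.gz, .tar.bz2);
    # anything that isn't a tar is assumed to be a zip archive.
    opener = tarfile.open if tarfile.is_tarfile(filepath) else zipfile.ZipFile
    return opener(filepath, "r")
```

This works because `TarFile` and `ZipFile` share the context-manager and `extractall` APIs used in `download_dataset`.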
11 changes: 11 additions & 0 deletions src/pytorch_adapt/datasets/utils.py
@@ -33,3 +33,14 @@ def maybe_download(fn, kwargs):
             fn(**kwargs)
         else:
             raise
+
+
+def num_classes(dataset_name):
+    return {
+        "mnist": 10,
+        "domainnet": 345,
+        "domainnet126": 126,
+        "office31": 31,
+        "officehome": 65,
+        "voc_multilabel": 20,
+    }[dataset_name]
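A quick usage sketch of the new helper (import path assumed from the file location above):

```python
from pytorch_adapt.datasets.utils import num_classes

print(num_classes("office31"))        # 31
print(num_classes("voc_multilabel"))  # 20
# Unknown names raise KeyError, since the lookup is a plain dict index.
```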
1 change: 1 addition & 0 deletions src/pytorch_adapt/frameworks/ignite/__init__.py
@@ -1,4 +1,5 @@
 from .checkpoint_utils import CheckpointFnCreator
 from .ignite import Ignite
+from .ignite_multilabel_classification import IgniteMultiLabelClassification
 from .ignite_preds_as_features import IgnitePredsAsFeatures
 from .ignite_val_hook_wrapper import IgniteValHookWrapper
6 changes: 5 additions & 1 deletion src/pytorch_adapt/frameworks/ignite/ignite.py
@@ -30,6 +30,7 @@ def __init__(
         with_pbars=True,
         device=None,
         auto_dist=True,
+        val_output_dict_fn=None,
     ):
         """
         Arguments:
@@ -77,6 +78,9 @@ def __init__(
         self.log_freq = log_freq
         self.with_pbars = with_pbars
         self.device = c_f.default(device, idist.device, {})
+        self.val_output_dict_fn = c_f.default(
+            val_output_dict_fn, f_utils.create_output_dict
+        )
         self.trainer_init()
         self.collector_init()
         self.dist_init_done = False
@@ -325,7 +329,7 @@ def evaluate_best_model(
     def get_collector_step(self, inference):
         def collector_step(engine, batch):
             batch = c_f.batch_to_device(batch, self.device)
-            return f_utils.collector_step(inference, batch, f_utils.create_output_dict)
+            return f_utils.collector_step(inference, batch, self.val_output_dict_fn)

         return collector_step
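Both added lines lean on the `c_f.default` fallback idiom. A minimal stand-in consistent with the two call sites visible in this diff, not the library's actual implementation: `c_f.default(device, idist.device, {})` calls its fallback, while `c_f.default(val_output_dict_fn, f_utils.create_output_dict)` returns the fallback uncalled.

```python
def default(value, fallback, call_kwargs=None):
    # Return the user-supplied value when given; otherwise fall back.
    # When call_kwargs is provided, the fallback is treated as a factory
    # and called with those kwargs (as with idist.device above).
    if value is not None:
        return value
    if call_kwargs is not None:
        return fallback(**call_kwargs)
    return fallback
```

The net effect of the new `val_output_dict_fn` argument is that callers can swap in their own function for building the validation output dict, with `f_utils.create_output_dict` remaining the default behavior.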