diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da4f768..e91a0b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,6 +45,7 @@ repos: rev: v1.5.0 hooks: - id: mypy + exclude: ^publications/ # TODO: license header hook diff --git a/publications/__init__.py b/publications/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/__init__.py b/publications/mdm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm/__init__.py b/publications/mdm/mdm/__init__.py index 7ae9963..89c1ea9 100644 --- a/publications/mdm/mdm/__init__.py +++ b/publications/mdm/mdm/__init__.py @@ -1,4 +1,3 @@ -from .model import (MDMModelHyperParams, MDMModel) -from .init_algorithm import (MDMInitializationAlgorithmParams, - MDMInitializationAlgorithm) -from .algorithm import (MDMAlgorithmParams, MDMAlgorithm) +from .algorithm import MDMAlgorithm, MDMAlgorithmParams +from .init_algorithm import MDMInitializationAlgorithm, MDMInitializationAlgorithmParams +from .model import MDMModel, MDMModelHyperParams diff --git a/publications/mdm/mdm/algorithm.py b/publications/mdm/mdm/algorithm.py index 0f3be49..0f360c6 100644 --- a/publications/mdm/mdm/algorithm.py +++ b/publications/mdm/mdm/algorithm.py @@ -1,22 +1,18 @@ -# -*- coding: utf-8 -*- - from dataclasses import dataclass -from typing import Tuple, Optional, TypeVar, Callable, Union +from typing import Callable, Optional, Tuple, TypeVar, Union import numpy as np import torch +from pfl.algorithm.base import AlgorithmHyperParams, FederatedAlgorithm from pfl.common_types import Population -from pfl.data.dataset import AbstractDataset +from pfl.context import CentralContext +from pfl.data.dataset import AbstractDataset, AbstractDatasetType from pfl.hyperparam import get_param_value from pfl.metrics import Metrics -from pfl.context import CentralContext from pfl.stats import MappedVectorStatistics -from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams -from pfl.data.dataset import AbstractDatasetType - -from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges +from publications.mdm.mdm.model import MDMModelHyperParamsType, MDMModelType @dataclass(frozen=True) @@ -44,7 +40,8 @@ class MDMAlgorithmParams(AlgorithmHyperParams): class MDMAlgorithm(FederatedAlgorithm[MDMAlgorithmParamsType, MDMModelHyperParamsType, MDMModelType, - MappedVectorStatistics, AbstractDatasetType]): + MappedVectorStatistics, + AbstractDatasetType]): """ Federated algorithm class for learning mixture of Polya (Dirichlet-Multinomial) distribution using MLE algorithm. @@ -161,7 +158,8 @@ def simulate_one_user( e[:, selected_bin] = posterior_probabilities.view(-1) statistics = MappedVectorStatistics() - statistics['posterior_probabilities'] = posterior_probabilities.to('cpu') + statistics['posterior_probabilities'] = posterior_probabilities.to( + 'cpu') statistics['numerator'] = numerator.to('cpu') statistics['denominator'] = denominator.to('cpu') statistics['num_samples_distribution'] = e.to('cpu') diff --git a/publications/mdm/mdm/bridge/__init__.py b/publications/mdm/mdm/bridge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm/bridge/base.py b/publications/mdm/mdm/bridge/base.py index f63e763..a9900ef 100644 --- a/publications/mdm/mdm/bridge/base.py +++ b/publications/mdm/mdm/bridge/base.py @@ -1,7 +1,4 @@ -# -*- coding: utf-8 -*- - -from typing import Any, Dict, Protocol, TypeVar, Tuple - +from typing import Any, Dict, Protocol, Tuple, TypeVar Tensor = TypeVar('Tensor') @@ -9,7 +6,7 @@ class PolyaMixtureFrameworkBridge(Protocol[Tensor]): """ Interface for Polya-Mixture algorithm for a particular Deep Learning - framework. + framework. """ @staticmethod diff --git a/publications/mdm/mdm/bridge/factory.py b/publications/mdm/mdm/bridge/factory.py index c395e26..4c3e449 100644 --- a/publications/mdm/mdm/bridge/factory.py +++ b/publications/mdm/mdm/bridge/factory.py @@ -8,7 +8,6 @@ ) from pfl.internal.ops.framework_types import MLFramework from pfl.internal.ops.selector import get_framework_module - from publications.mdm.mdm.bridge.base import PolyaMixtureFrameworkBridge diff --git a/publications/mdm/mdm/bridge/pytorch/polya_mixture.py b/publications/mdm/mdm/bridge/pytorch/polya_mixture.py index a43ed9c..463e42a 100644 --- a/publications/mdm/mdm/bridge/pytorch/polya_mixture.py +++ b/publications/mdm/mdm/bridge/pytorch/polya_mixture.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- - from typing import Tuple + import torch from ..base import PolyaMixtureFrameworkBridge @@ -40,11 +39,13 @@ def category_probabilities_polya_mixture_initialization( def expectation_step(phi, alphas, num_samples_distribution, category_counts) -> torch.Tensor: if (num_samples_distribution == 0).any(): - raise AssertionError('num_samples_distribution contains zero values, which cannot work with expectation step on clients') + raise AssertionError( + 'num_samples_distribution contains zero values, which cannot work with expectation step on clients' + ) # E Step - compute posterior probability of each component # Compute log prior + log likelihood - # TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False) + # TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False) phi = torch.Tensor(phi).to('cpu') alphas = torch.Tensor(alphas).to('cpu') category_counts = category_counts.to('cpu') @@ -56,7 +57,7 @@ def expectation_step(phi, alphas, num_samples_distribution, torch.sum( torch.lgamma(category_counts + alphas) - torch.lgamma(alphas), dim=1, - keepdim=False)) + torch.log(num_samples_distribution) + keepdim=False)) + torch.log(num_samples_distribution) # TODO Ignore this as log(0) => NaN # TODO fix this equation so that it works with num_samples_distribution = 0 @@ -75,12 +76,13 @@ def expectation_step(phi, alphas, num_samples_distribution, @staticmethod def maximization_step(posterior_probabilities, category_counts, - alphas) -> torch.Tensor: + alphas) -> torch.Tensor: # M Step - compute client update to alphas for fixed point update # which will be applied by the model in process_aggregated_statistics. # Note the numerator and denominator are both weighted by w (the # probability vector giving the client belonging to each component). - posterior_probabilities = torch.Tensor(posterior_probabilities).to('cpu') + posterior_probabilities = torch.Tensor(posterior_probabilities).to( + 'cpu') category_counts = torch.Tensor(category_counts).to('cpu') alphas = torch.Tensor(alphas).to('cpu') numerator = posterior_probabilities.reshape( diff --git a/publications/mdm/mdm/init_algorithm.py b/publications/mdm/mdm/init_algorithm.py index 9a24c1b..9acaaee 100644 --- a/publications/mdm/mdm/init_algorithm.py +++ b/publications/mdm/mdm/init_algorithm.py @@ -1,24 +1,20 @@ -# -*- coding: utf-8 -*- - -from dataclasses import dataclass -from typing import Tuple, Optional, TypeVar, Callable, Union from collections import defaultdict +from dataclasses import dataclass +from typing import Callable, Optional, Tuple, TypeVar, Union import numpy as np import torch +from pfl.algorithm.base import AlgorithmHyperParams, FederatedAlgorithm from pfl.common_types import Population -from pfl.data.dataset import AbstractDataset +from pfl.context import CentralContext +from pfl.data.dataset import AbstractDataset, AbstractDatasetType from pfl.hyperparam import get_param_value +from pfl.internal.ops import get_ops from pfl.metrics import Metrics -from pfl.context import CentralContext from pfl.stats import MappedVectorStatistics -from pfl.internal.ops import get_ops -from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams -from pfl.data.dataset import AbstractDatasetType - -from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges +from publications.mdm.mdm.model import MDMModelHyperParamsType, MDMModelType @dataclass(frozen=True) diff --git a/publications/mdm/mdm/model.py b/publications/mdm/mdm/model.py index 3372fa3..65813f0 100644 --- a/publications/mdm/mdm/model.py +++ b/publications/mdm/mdm/model.py @@ -1,20 +1,18 @@ -# -*- coding: utf-8 -*- - -from typing import TypeVar, Generic, Tuple, List, Union, Optional -from dataclasses import dataclass import os -import joblib +from dataclasses import dataclass +from typing import Generic, List, Optional, Tuple, TypeVar, Union +import joblib import numpy as np -from pfl.hyperparam.base import ModelHyperParams -from pfl.model.base import Model -from pfl.metrics import Metrics -from pfl.stats import MappedVectorStatistics from pfl.exception import CheckpointNotFoundError -from pfl.internal.ops.selector import set_framework_module +from pfl.hyperparam.base import ModelHyperParams from pfl.internal.ops import pytorch_ops from pfl.internal.ops.selector import get_default_framework_module as get_ops +from pfl.internal.ops.selector import set_framework_module +from pfl.metrics import Metrics +from pfl.model.base import Model +from pfl.stats import MappedVectorStatistics Tensor = TypeVar('Tensor') FrameworkModelType = TypeVar('FrameworkModelType') diff --git a/publications/mdm/mdm_paper/notebooks/__init__.py b/publications/mdm/mdm_paper/notebooks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm_paper/training/mle.py b/publications/mdm/mdm_paper/training/mle.py index 3376740..88d9394 100644 --- a/publications/mdm/mdm_paper/training/mle.py +++ b/publications/mdm/mdm_paper/training/mle.py @@ -2,13 +2,16 @@ from pfl.aggregate.simulate import SimulatedBackend from pfl.callback import ModelCheckpointingCallback -from pfl.privacy import (CentrallyAppliedPrivacyMechanism, PLDPrivacyAccountant, GaussianMechanism) - +from pfl.privacy import CentrallyAppliedPrivacyMechanism, GaussianMechanism, PLDPrivacyAccountant +from publications.mdm.mdm import ( + MDMAlgorithm, + MDMAlgorithmParams, + MDMInitializationAlgorithm, + MDMInitializationAlgorithmParams, + MDMModel, + MDMModelHyperParams, +) from publications.mdm.mdm_utils.utils.tools import ModelCheckpointingIterationCallback -from publications.mdm.mdm import (MDMModel, MDMModelHyperParams, - MDMAlgorithm, MDMAlgorithmParams, - MDMInitializationAlgorithm, - MDMInitializationAlgorithmParams) def solve_polya_mixture_mle( @@ -33,13 +36,12 @@ def solve_polya_mixture_mle( if add_DP: num_iterations = arguments.central_num_iterations_init_algorithm + arguments.central_num_iterations_algorithm - accountant = PLDPrivacyAccountant( - num_compositions=num_iterations, - sampling_probability=0.001, - mechanism='gaussian', - epsilon=2, - delta=1e-7, - noise_scale=1.0) + accountant = PLDPrivacyAccountant(num_compositions=num_iterations, + sampling_probability=0.001, + mechanism='gaussian', + epsilon=2, + delta=1e-7, + noise_scale=1.0) mechanism = GaussianMechanism.from_privacy_accountant( accountant=accountant, clipping_bound=0.5) diff --git a/publications/mdm/mdm_paper/training/train.py b/publications/mdm/mdm_paper/training/train.py index 9431ccd..811cefe 100644 --- a/publications/mdm/mdm_paper/training/train.py +++ b/publications/mdm/mdm_paper/training/train.py @@ -1,24 +1,26 @@ -import os import argparse import datetime +import os +import joblib import numpy as np import torch -import joblib from pfl.internal.ops import pytorch_ops from pfl.internal.ops.selector import get_default_framework_module as get_ops from pfl.internal.ops.selector import set_framework_module from pfl.internal.platform.selector import get_platform - -from publications.mdm.mdm_utils.datasets import make_cifar10_datasets -from publications.mdm.mdm_utils.utils import (add_dataset_args, add_experiment_args, - add_mle_args, add_init_algorithm_args, - add_algorithm_args, - add_histogram_algorithm_args, - add_user_visualisation_args) - from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle +from publications.mdm.mdm_utils.datasets import make_cifar10_datasets +from publications.mdm.mdm_utils.utils import ( + add_algorithm_args, + add_dataset_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) def get_arguments(): @@ -104,7 +106,8 @@ def get_arguments(): print('simulated_dirichlet_mixture experiment') if arguments.precomputed_parameter_filepath is None: print('learn simulated_dirichlet_mixture parameters') - dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0] + dir_path = get_platform().create_checkpoint_directories( + [arguments.mle_param_dirname])[0] current_time = datetime.datetime.now() timestamp = current_time.strftime("%Y-%m-%d_%H-%M") save_dir = ( diff --git a/publications/mdm/mdm_paper/training/train_femnist.py b/publications/mdm/mdm_paper/training/train_femnist.py index dd6c90e..9b118b3 100644 --- a/publications/mdm/mdm_paper/training/train_femnist.py +++ b/publications/mdm/mdm_paper/training/train_femnist.py @@ -1,5 +1,5 @@ -import os import argparse +import os import joblib import numpy as np @@ -9,14 +9,16 @@ from pfl.internal.ops.selector import get_default_framework_module as get_ops from pfl.internal.ops.selector import set_framework_module from pfl.internal.platform.selector import get_platform - -from publications.mdm.mdm_utils.datasets import make_femnist_datasets -from publications.mdm.mdm_utils.utils import (add_experiment_args, add_mle_args, - add_init_algorithm_args, add_algorithm_args, - add_histogram_algorithm_args, - add_user_visualisation_args) - from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle +from publications.mdm.mdm_utils.datasets import make_femnist_datasets +from publications.mdm.mdm_utils.utils import ( + add_algorithm_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) def get_arguments(): @@ -56,7 +58,8 @@ def get_arguments(): print('simulated_dirichlet_mixture experiment') if arguments.precomputed_parameter_filepath is None: print('learn simulated_dirichlet_mixture parameters') - dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0] + dir_path = get_platform().create_checkpoint_directories( + [arguments.mle_param_dirname])[0] save_dir = ( f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture' ) diff --git a/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py b/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py index 096c7ec..9848838 100644 --- a/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py +++ b/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py @@ -1,5 +1,5 @@ -import os import argparse +import os import joblib import numpy as np @@ -9,15 +9,17 @@ from pfl.internal.ops.selector import get_default_framework_module as get_ops from pfl.internal.ops.selector import set_framework_module from pfl.internal.platform.selector import get_platform - -from publications.mdm.mdm_utils.datasets import make_femnist_datasets -from publications.mdm.mdm_utils.utils import (add_experiment_args, add_mle_args, - add_init_algorithm_args, add_algorithm_args, - add_histogram_algorithm_args, - add_user_visualisation_args, - add_dataset_preprocessing_args) - from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle +from publications.mdm.mdm_utils.datasets import make_femnist_datasets +from publications.mdm.mdm_utils.utils import ( + add_algorithm_args, + add_dataset_preprocessing_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) def get_arguments(): @@ -63,7 +65,8 @@ def get_arguments(): print('simulated_dirichlet_mixture experiment') if arguments.precomputed_parameter_filepath is None: print('learn simulated_dirichlet_mixture parameters') - dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0] + dir_path = get_platform().create_checkpoint_directories( + [arguments.mle_param_dirname])[0] save_dir = ( f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture_{arguments.filter_method}_filter_method' ) diff --git a/publications/mdm/mdm_utils/datasets/__init__.py b/publications/mdm/mdm_utils/datasets/__init__.py index e189449..533151f 100644 --- a/publications/mdm/mdm_utils/datasets/__init__.py +++ b/publications/mdm/mdm_utils/datasets/__init__.py @@ -1,3 +1,3 @@ -from .mixture_dataset import get_user_counts from .cifar10_dataset import make_cifar10_datasets from .femnist_dataset import make_femnist_datasets +from .mixture_dataset import get_user_counts diff --git a/publications/mdm/mdm_utils/datasets/cifar10_dataset.py b/publications/mdm/mdm_utils/datasets/cifar10_dataset.py index 3d06410..bd27c0e 100644 --- a/publications/mdm/mdm_utils/datasets/cifar10_dataset.py +++ b/publications/mdm/mdm_utils/datasets/cifar10_dataset.py @@ -1,19 +1,14 @@ -# -*- coding: utf-8 -*- - import os import pickle from typing import Callable, List, Optional, Tuple import numpy as np -from pfl.data import (ArtificialFederatedDataset, FederatedDataset, - FederatedDatasetBase) -from pfl.data.sampling import get_user_sampler, get_data_sampler +from pfl.data import ArtificialFederatedDataset, FederatedDataset, FederatedDatasetBase from pfl.data.dataset import Dataset +from pfl.data.sampling import get_data_sampler, get_user_sampler -from .mixture_dataset import (ArtificialFederatedDatasetMixture, - partition_by_dirichlet_mixture_class_distribution - ) +from .mixture_dataset import ArtificialFederatedDatasetMixture, partition_by_dirichlet_mixture_class_distribution from .sampler import DirichletDataSampler @@ -94,7 +89,7 @@ def make_federated_dataset( images = numpy_to_tensor(images) labels = numpy_to_tensor(labels) - data = dict() + data = {} for user_id in range(len(user_idxs)): data[user_id] = [ images[user_idxs[user_id]], labels[user_idxs[user_id]] diff --git a/publications/mdm/mdm_utils/datasets/femnist_dataset.py b/publications/mdm/mdm_utils/datasets/femnist_dataset.py index c921e4f..21880a8 100644 --- a/publications/mdm/mdm_utils/datasets/femnist_dataset.py +++ b/publications/mdm/mdm_utils/datasets/femnist_dataset.py @@ -1,27 +1,23 @@ -# -*- coding: utf-8 -*- - import os -from typing import Callable, Dict, Tuple, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import h5py import numpy as np import torch from pfl.data import ArtificialFederatedDataset, FederatedDataset -from pfl.data.sampling import get_user_sampler, get_data_sampler from pfl.data.dataset import Dataset +from pfl.data.sampling import get_data_sampler, get_user_sampler -from .mixture_dataset import (ArtificialFederatedDatasetMixture, - partition_by_dirichlet_mixture_class_distribution - ) +from .mixture_dataset import ArtificialFederatedDatasetMixture, partition_by_dirichlet_mixture_class_distribution from .sampler import DirichletDataSampler def _sample_users(user_id_to_data: Dict[str, List[np.ndarray]], filter_method: Optional[str] = None, - sample_fraction: float = None, - start_idx: int = None, - end_idx: int = None, + sample_fraction: Optional[float] = None, + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, include_sampled: bool = True): user_ids = list(user_id_to_data.keys()) @@ -156,7 +152,7 @@ def make_special_federated_dataset( #for k,v in indices_per_class.items(): # print('indices_per_class', k, len(v)) - new_user_id_to_data = dict() + new_user_id_to_data = {} start_id_per_class = {i: 0 for i in unique_labels} #print('start_id_per_class', start_id_per_class) for user_id, data in user_id_to_data.items(): @@ -218,10 +214,10 @@ def make_central_dataset( """ Create central dataset from a FEMNIST data file. """ - images = np.concatenate([data[0].cpu() for data in user_id_to_data.values()], - axis=0) - labels = np.concatenate([data[1].cpu() for data in user_id_to_data.values()], - axis=0) + images = np.concatenate( + [data[0].cpu() for data in user_id_to_data.values()], axis=0) + labels = np.concatenate( + [data[1].cpu() for data in user_id_to_data.values()], axis=0) return Dataset(raw_data=[images, labels]) @@ -235,9 +231,9 @@ def make_femnist_datasets( alphas=None, user_dataset_len_samplers=None, filter_method: Optional[str] = None, - sample_fraction: float = None, - start_idx: int = None, - end_idx: int = None, + sample_fraction: Optional[float] = None, + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, include_sampled: bool = True ) -> Tuple[FederatedDataset, FederatedDataset, Dataset]: """ diff --git a/publications/mdm/mdm_utils/datasets/mixture_dataset.py b/publications/mdm/mdm_utils/datasets/mixture_dataset.py index 495d4b1..2de7789 100644 --- a/publications/mdm/mdm_utils/datasets/mixture_dataset.py +++ b/publications/mdm/mdm_utils/datasets/mixture_dataset.py @@ -1,12 +1,12 @@ from collections import defaultdict from typing import Callable, Iterable, List, Tuple -import numpy as np import joblib +import numpy as np from pfl.data import ArtificialFederatedDataset, FederatedDatasetBase from pfl.data.dataset import AbstractDataset -from pfl.internal.ops.selector import (get_default_framework_module as get_ops) +from pfl.internal.ops.selector import get_default_framework_module as get_ops class ArtificialFederatedDatasetMixture(FederatedDatasetBase): @@ -139,7 +139,7 @@ def get_user_counts(training_federated_dataset, num_classes, over a number of central iterations in train.py. """ print('get_user_counts') - all_counts = dict() + all_counts = {} for r in range(num_central_iterations): all_counts[r + 1] = [] l = list(training_federated_dataset.get_cohort(cohort_size)) diff --git a/publications/mdm/mdm_utils/datasets/sampler.py b/publications/mdm/mdm_utils/datasets/sampler.py index 070783b..c419d53 100644 --- a/publications/mdm/mdm_utils/datasets/sampler.py +++ b/publications/mdm/mdm_utils/datasets/sampler.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import itertools import numpy as np diff --git a/publications/mdm/mdm_utils/models/argument_parsing.py b/publications/mdm/mdm_utils/models/argument_parsing.py index 3d70d70..3f109a8 100644 --- a/publications/mdm/mdm_utils/models/argument_parsing.py +++ b/publications/mdm/mdm_utils/models/argument_parsing.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from typing import Optional, Tuple diff --git a/publications/mdm/mdm_utils/models/pytorch/__init__.py b/publications/mdm/mdm_utils/models/pytorch/__init__.py index 664abd6..ee4accf 100644 --- a/publications/mdm/mdm_utils/models/pytorch/__init__.py +++ b/publications/mdm/mdm_utils/models/pytorch/__init__.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -from .cnn import simple_cnn, multi_label_cnn +from .cnn import multi_label_cnn, simple_cnn from .dnn import dnn, simple_dnn from .lstm import lm_lstm from .transformer import lm_transformer diff --git a/publications/mdm/mdm_utils/models/pytorch/cnn.py b/publications/mdm/mdm_utils/models/pytorch/cnn.py index ef85a40..f8d00ba 100644 --- a/publications/mdm/mdm_utils/models/pytorch/cnn.py +++ b/publications/mdm/mdm_utils/models/pytorch/cnn.py @@ -1,17 +1,16 @@ -# -*- coding: utf-8 -*- - import types -from typing import Tuple, List +from typing import List, Tuple import numpy as np import torch # type: ignore import torch.nn as nn import torch.nn.functional as F + from pfl.metrics import Weighted -from .layer import Transpose2D -from .metrics import image_classification_metrics, image_classification_loss from ..numpy.metrics import AveragedPrecision, MacroWeighted +from .layer import Transpose2D +from .metrics import image_classification_loss, image_classification_metrics def multi_label_cnn( @@ -41,9 +40,8 @@ def multi_label_cnn( import torchvision.models # type: ignore import torchvision.transforms as transforms # type: ignore - from .module_modification import (validate_no_batchnorm, - freeze_batchnorm_modules, - convert_batchnorm_modules) + + from .module_modification import convert_batchnorm_modules, freeze_batchnorm_modules, validate_no_batchnorm torchvision_models = torchvision.models.__dict__ @@ -218,7 +216,7 @@ def simple_cnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module: # Apply Glorot (Xavier) uniform initialization to match TF2 model. for m in model.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) model.loss = types.MethodType(image_classification_loss, model) diff --git a/publications/mdm/mdm_utils/models/pytorch/dnn.py b/publications/mdm/mdm_utils/models/pytorch/dnn.py index a07e152..b6ac8bf 100644 --- a/publications/mdm/mdm_utils/models/pytorch/dnn.py +++ b/publications/mdm/mdm_utils/models/pytorch/dnn.py @@ -1,13 +1,11 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple import functools import types +from typing import Tuple -import torch.nn as nn import numpy as np +import torch.nn as nn -from .metrics import image_classification_metrics, image_classification_loss +from .metrics import image_classification_loss, image_classification_metrics def dnn(input_shape: Tuple[int, ...], hidden_dims: Tuple[int, ...], diff --git a/publications/mdm/mdm_utils/models/pytorch/layer.py b/publications/mdm/mdm_utils/models/pytorch/layer.py index b21d35d..9c62736 100644 --- a/publications/mdm/mdm_utils/models/pytorch/layer.py +++ b/publications/mdm/mdm_utils/models/pytorch/layer.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- - from abc import ABC import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.batchnorm import _NormBase + from ..numpy.layer import positional_encoding @@ -22,10 +21,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # and use pretrained statistics in training as well self.training = False - if self.momentum is None: - exponential_average_factor = 0.0 - else: - exponential_average_factor = self.momentum + exponential_average_factor = 0.0 if self.momentum is None else self.momentum bn_training = (self.running_mean is None) and (self.running_var is None) @@ -48,24 +44,22 @@ class FrozenBatchNorm1D(_FrozenBatchNorm): def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)'.format( - input.dim())) + raise ValueError( + f'expected 2D or 3D input (got {input.dim()}D input)') class FrozenBatchNorm2D(_FrozenBatchNorm): def _check_input_dim(self, input): if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)'.format( - input.dim())) + raise ValueError(f'expected 4D input (got {input.dim()}D input)') class FrozenBatchNorm3D(_FrozenBatchNorm): def _check_input_dim(self, input): if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)'.format( - input.dim())) + raise ValueError(f'expected 5D input (got {input.dim()}D input)') class Transpose2D(nn.Module): diff --git a/publications/mdm/mdm_utils/models/pytorch/metrics.py b/publications/mdm/mdm_utils/models/pytorch/metrics.py index 2f9f096..92c1523 100644 --- a/publications/mdm/mdm_utils/models/pytorch/metrics.py +++ b/publications/mdm/mdm_utils/models/pytorch/metrics.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from typing import Dict import torch diff --git a/publications/mdm/mdm_utils/models/pytorch/module_modification.py b/publications/mdm/mdm_utils/models/pytorch/module_modification.py index 39d9047..345fe64 100644 --- a/publications/mdm/mdm_utils/models/pytorch/module_modification.py +++ b/publications/mdm/mdm_utils/models/pytorch/module_modification.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from typing import Callable, Type from torch import nn diff --git a/publications/mdm/mdm_utils/models/pytorch_model.py b/publications/mdm/mdm_utils/models/pytorch_model.py index cadc368..5ab603e 100644 --- a/publications/mdm/mdm_utils/models/pytorch_model.py +++ b/publications/mdm/mdm_utils/models/pytorch_model.py @@ -1,5 +1,5 @@ import types -from typing import Tuple, Dict +from typing import Dict, Tuple import torch from torch import nn @@ -42,7 +42,7 @@ def simple_cnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module: # Apply Glorot (Xavier) uniform initialization to match TF2 model. for m in model.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) model.loss = types.MethodType(image_classification_loss, model) diff --git a/publications/mdm/mdm_utils/utils/__init__.py b/publications/mdm/mdm_utils/utils/__init__.py index b3fccc1..37bf735 100644 --- a/publications/mdm/mdm_utils/utils/__init__.py +++ b/publications/mdm/mdm_utils/utils/__init__.py @@ -1,5 +1,10 @@ -from .argument_parsing import (add_dataset_args, add_experiment_args, - add_init_algorithm_args, add_algorithm_args, - add_mle_args, add_histogram_algorithm_args, - add_user_visualisation_args, - add_dataset_preprocessing_args) +from .argument_parsing import ( + add_algorithm_args, + add_dataset_args, + add_dataset_preprocessing_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) diff --git a/publications/mdm/mdm_utils/utils/argument_parsing.py b/publications/mdm/mdm_utils/utils/argument_parsing.py index 2650e66..e936c8b 100644 --- a/publications/mdm/mdm_utils/utils/argument_parsing.py +++ b/publications/mdm/mdm_utils/utils/argument_parsing.py @@ -7,12 +7,12 @@ def __init__(self, option_strings, dest, **kwargs): argparse.Action.__init__(self, option_strings, dest, **kwargs) def __call__(self, parser, namespace, values, option_string=None): - false_values = set(['false', 'no']) - true_values = set(['true', 'yes']) + false_values = {'false', 'no'} + true_values = {'true', 'yes'} values = values.lower() - if not values in (false_values | true_values): + if values not in (false_values | true_values): raise argparse.ArgumentError( self, 'Value must be either "true" or "false"') value = (values in true_values) @@ -24,7 +24,9 @@ def add_experiment_args(parser): parser.add_argument('--seed', type=int, default=0) parser.add_argument('--data_dir', type=str) parser.add_argument('--dirname', type=str) - parser.add_argument('--mle_param_dirname', type=str, default='publications/mdm/mle_params') + parser.add_argument('--mle_param_dirname', + type=str, + default='publications/mdm/mle_params') parser.add_argument( '--precomputed_parameter_filepath', type=str, @@ -61,17 +63,19 @@ def add_dataset_preprocessing_args(parser): def float_list(arg): try: float_values = [float(val) for val in arg.split()] - return float_values except ValueError: raise argparse.ArgumentTypeError("Invalid float values in the list") + else: + return float_values def int_list(arg): try: int_values = [int(val) for val in arg.split()] - return int_values except ValueError: raise argparse.ArgumentTypeError("Invalid int values in the list") + else: + return int_values def add_dataset_args(parser): diff --git a/publications/mdm/mdm_utils/utils/tools.py b/publications/mdm/mdm_utils/utils/tools.py index 6280ee1..44deab7 100644 --- a/publications/mdm/mdm_utils/utils/tools.py +++ b/publications/mdm/mdm_utils/utils/tools.py @@ -3,7 +3,6 @@ from pfl.callback import TrainingProcessCallback from pfl.internal.ops.selector import get_default_framework_module as get_ops - from pfl.metrics import Metrics from pfl.model.base import StatefulModel diff --git a/publications/mdm/mdm_utils/utils/visualize_results.py b/publications/mdm/mdm_utils/utils/visualize_results.py index 281f39d..67ddc80 100644 --- a/publications/mdm/mdm_utils/utils/visualize_results.py +++ b/publications/mdm/mdm_utils/utils/visualize_results.py @@ -13,7 +13,7 @@ def plot_cifar10_results(): df = pd.read_csv(filename) experiments = np.unique(df['experiment'].values).tolist() - dfs = dict() + dfs = {} for experiment in experiments: dfs[experiment] = (df.loc[df['experiment'] == experiment]) @@ -21,13 +21,13 @@ def plot_cifar10_results(): 'cohort_size', 'local_batch_size', 'local_learning_rate', 'local_num_epochs' ] - unique_vals = dict() + unique_vals = {} for column_name in column_names: unique_vals[column_name] = np.unique(dfs['live'][column_name]).tolist() - accs = dict() + accs = {} for name, df in dfs.items(): - accs[name] = dict() + accs[name] = {} for tup in product(*unique_vals.values()): filter_dic = dict(zip(column_names, tup)) a = df.loc[(df[list(filter_dic)] == pd.Series(filter_dic)).all( @@ -38,7 +38,7 @@ def plot_cifar10_results(): permutation = np.argsort(-x) mask = np.array(list(accs['live'].values()))[permutation] >= 0.6 - dic = dict() + dic = {} c = dict(zip(accs.keys(), ['blue', 'red', 'green'])) plt.rcParams.update({'font.size': 13})