Skip to content

Commit

Permalink
ruff and mypy on publications/mdm (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
ac554 authored Jun 4, 2024
1 parent 19ef437 commit 2065d11
Show file tree
Hide file tree
Showing 33 changed files with 159 additions and 179 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ repos:
rev: v1.5.0
hooks:
- id: mypy
exclude: ^publications/
# TODO: license header hook


Empty file added publications/__init__.py
Empty file.
Empty file added publications/mdm/__init__.py
Empty file.
7 changes: 3 additions & 4 deletions publications/mdm/mdm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .model import (MDMModelHyperParams, MDMModel)
from .init_algorithm import (MDMInitializationAlgorithmParams,
MDMInitializationAlgorithm)
from .algorithm import (MDMAlgorithmParams, MDMAlgorithm)
from .algorithm import MDMAlgorithm, MDMAlgorithmParams
from .init_algorithm import MDMInitializationAlgorithm, MDMInitializationAlgorithmParams
from .model import MDMModel, MDMModelHyperParams
20 changes: 9 additions & 11 deletions publications/mdm/mdm/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
# -*- coding: utf-8 -*-

from dataclasses import dataclass
from typing import Tuple, Optional, TypeVar, Callable, Union
from typing import Callable, Optional, Tuple, TypeVar, Union

import numpy as np
import torch

from pfl.algorithm.base import AlgorithmHyperParams, FederatedAlgorithm
from pfl.common_types import Population
from pfl.data.dataset import AbstractDataset
from pfl.context import CentralContext
from pfl.data.dataset import AbstractDataset, AbstractDatasetType
from pfl.hyperparam import get_param_value
from pfl.metrics import Metrics
from pfl.context import CentralContext
from pfl.stats import MappedVectorStatistics
from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams
from pfl.data.dataset import AbstractDatasetType

from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType
from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges
from publications.mdm.mdm.model import MDMModelHyperParamsType, MDMModelType


@dataclass(frozen=True)
Expand Down Expand Up @@ -44,7 +40,8 @@ class MDMAlgorithmParams(AlgorithmHyperParams):

class MDMAlgorithm(FederatedAlgorithm[MDMAlgorithmParamsType,
MDMModelHyperParamsType, MDMModelType,
MappedVectorStatistics, AbstractDatasetType]):
MappedVectorStatistics,
AbstractDatasetType]):
"""
Federated algorithm class for learning mixture of Polya
(Dirichlet-Multinomial) distribution using MLE algorithm.
Expand Down Expand Up @@ -161,7 +158,8 @@ def simulate_one_user(
e[:, selected_bin] = posterior_probabilities.view(-1)

statistics = MappedVectorStatistics()
statistics['posterior_probabilities'] = posterior_probabilities.to('cpu')
statistics['posterior_probabilities'] = posterior_probabilities.to(
'cpu')
statistics['numerator'] = numerator.to('cpu')
statistics['denominator'] = denominator.to('cpu')
statistics['num_samples_distribution'] = e.to('cpu')
Expand Down
Empty file.
7 changes: 2 additions & 5 deletions publications/mdm/mdm/bridge/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# -*- coding: utf-8 -*-

from typing import Any, Dict, Protocol, TypeVar, Tuple

from typing import Any, Dict, Protocol, Tuple, TypeVar

Tensor = TypeVar('Tensor')


class PolyaMixtureFrameworkBridge(Protocol[Tensor]):
"""
Interface for Polya-Mixture algorithm for a particular Deep Learning
framework.
framework.
"""

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion publications/mdm/mdm/bridge/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
)
from pfl.internal.ops.framework_types import MLFramework
from pfl.internal.ops.selector import get_framework_module

from publications.mdm.mdm.bridge.base import PolyaMixtureFrameworkBridge


Expand Down
16 changes: 9 additions & 7 deletions publications/mdm/mdm/bridge/pytorch/polya_mixture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-

from typing import Tuple

import torch

from ..base import PolyaMixtureFrameworkBridge
Expand Down Expand Up @@ -40,11 +39,13 @@ def category_probabilities_polya_mixture_initialization(
def expectation_step(phi, alphas, num_samples_distribution,
category_counts) -> torch.Tensor:
if (num_samples_distribution == 0).any():
raise AssertionError('num_samples_distribution contains zero values, which cannot work with expectation step on clients')
raise AssertionError(
'num_samples_distribution contains zero values, which cannot work with expectation step on clients'
)

# E Step - compute posterior probability of each component
# Compute log prior + log likelihood
# TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False)
# TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False)
phi = torch.Tensor(phi).to('cpu')
alphas = torch.Tensor(alphas).to('cpu')
category_counts = category_counts.to('cpu')
Expand All @@ -56,7 +57,7 @@ def expectation_step(phi, alphas, num_samples_distribution,
torch.sum(
torch.lgamma(category_counts + alphas) - torch.lgamma(alphas),
dim=1,
keepdim=False)) + torch.log(num_samples_distribution)
keepdim=False)) + torch.log(num_samples_distribution)

# TODO Ignore this as log(0) => NaN
# TODO fix this equation so that it works with num_samples_distribution = 0
Expand All @@ -75,12 +76,13 @@ def expectation_step(phi, alphas, num_samples_distribution,

@staticmethod
def maximization_step(posterior_probabilities, category_counts,
alphas) -> torch.Tensor:
alphas) -> torch.Tensor:
# M Step - compute client update to alphas for fixed point update
# which will be applied by the model in process_aggregated_statistics.
# Note the numerator and denominator are both weighted by w (the
# probability vector giving the client belonging to each component).
posterior_probabilities = torch.Tensor(posterior_probabilities).to('cpu')
posterior_probabilities = torch.Tensor(posterior_probabilities).to(
'cpu')
category_counts = torch.Tensor(category_counts).to('cpu')
alphas = torch.Tensor(alphas).to('cpu')
numerator = posterior_probabilities.reshape(
Expand Down
18 changes: 7 additions & 11 deletions publications/mdm/mdm/init_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
# -*- coding: utf-8 -*-

from dataclasses import dataclass
from typing import Tuple, Optional, TypeVar, Callable, Union
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, TypeVar, Union

import numpy as np
import torch

from pfl.algorithm.base import AlgorithmHyperParams, FederatedAlgorithm
from pfl.common_types import Population
from pfl.data.dataset import AbstractDataset
from pfl.context import CentralContext
from pfl.data.dataset import AbstractDataset, AbstractDatasetType
from pfl.hyperparam import get_param_value
from pfl.internal.ops import get_ops
from pfl.metrics import Metrics
from pfl.context import CentralContext
from pfl.stats import MappedVectorStatistics
from pfl.internal.ops import get_ops
from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams
from pfl.data.dataset import AbstractDatasetType

from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType
from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges
from publications.mdm.mdm.model import MDMModelHyperParamsType, MDMModelType


@dataclass(frozen=True)
Expand Down
18 changes: 8 additions & 10 deletions publications/mdm/mdm/model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
# -*- coding: utf-8 -*-

from typing import TypeVar, Generic, Tuple, List, Union, Optional
from dataclasses import dataclass
import os
import joblib
from dataclasses import dataclass
from typing import Generic, List, Optional, Tuple, TypeVar, Union

import joblib
import numpy as np

from pfl.hyperparam.base import ModelHyperParams
from pfl.model.base import Model
from pfl.metrics import Metrics
from pfl.stats import MappedVectorStatistics
from pfl.exception import CheckpointNotFoundError
from pfl.internal.ops.selector import set_framework_module
from pfl.hyperparam.base import ModelHyperParams
from pfl.internal.ops import pytorch_ops
from pfl.internal.ops.selector import get_default_framework_module as get_ops
from pfl.internal.ops.selector import set_framework_module
from pfl.metrics import Metrics
from pfl.model.base import Model
from pfl.stats import MappedVectorStatistics

Tensor = TypeVar('Tensor')
FrameworkModelType = TypeVar('FrameworkModelType')
Expand Down
Empty file.
28 changes: 15 additions & 13 deletions publications/mdm/mdm_paper/training/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from pfl.aggregate.simulate import SimulatedBackend
from pfl.callback import ModelCheckpointingCallback
from pfl.privacy import (CentrallyAppliedPrivacyMechanism, PLDPrivacyAccountant, GaussianMechanism)

from pfl.privacy import CentrallyAppliedPrivacyMechanism, GaussianMechanism, PLDPrivacyAccountant
from publications.mdm.mdm import (
MDMAlgorithm,
MDMAlgorithmParams,
MDMInitializationAlgorithm,
MDMInitializationAlgorithmParams,
MDMModel,
MDMModelHyperParams,
)
from publications.mdm.mdm_utils.utils.tools import ModelCheckpointingIterationCallback
from publications.mdm.mdm import (MDMModel, MDMModelHyperParams,
MDMAlgorithm, MDMAlgorithmParams,
MDMInitializationAlgorithm,
MDMInitializationAlgorithmParams)


def solve_polya_mixture_mle(
Expand All @@ -33,13 +36,12 @@ def solve_polya_mixture_mle(
if add_DP:
num_iterations = arguments.central_num_iterations_init_algorithm + arguments.central_num_iterations_algorithm

accountant = PLDPrivacyAccountant(
num_compositions=num_iterations,
sampling_probability=0.001,
mechanism='gaussian',
epsilon=2,
delta=1e-7,
noise_scale=1.0)
accountant = PLDPrivacyAccountant(num_compositions=num_iterations,
sampling_probability=0.001,
mechanism='gaussian',
epsilon=2,
delta=1e-7,
noise_scale=1.0)
mechanism = GaussianMechanism.from_privacy_accountant(
accountant=accountant, clipping_bound=0.5)

Expand Down
25 changes: 14 additions & 11 deletions publications/mdm/mdm_paper/training/train.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import os
import argparse
import datetime
import os

import joblib
import numpy as np
import torch
import joblib

from pfl.internal.ops import pytorch_ops
from pfl.internal.ops.selector import get_default_framework_module as get_ops
from pfl.internal.ops.selector import set_framework_module
from pfl.internal.platform.selector import get_platform

from publications.mdm.mdm_utils.datasets import make_cifar10_datasets
from publications.mdm.mdm_utils.utils import (add_dataset_args, add_experiment_args,
add_mle_args, add_init_algorithm_args,
add_algorithm_args,
add_histogram_algorithm_args,
add_user_visualisation_args)

from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle
from publications.mdm.mdm_utils.datasets import make_cifar10_datasets
from publications.mdm.mdm_utils.utils import (
add_algorithm_args,
add_dataset_args,
add_experiment_args,
add_histogram_algorithm_args,
add_init_algorithm_args,
add_mle_args,
add_user_visualisation_args,
)


def get_arguments():
Expand Down Expand Up @@ -104,7 +106,8 @@ def get_arguments():
print('simulated_dirichlet_mixture experiment')
if arguments.precomputed_parameter_filepath is None:
print('learn simulated_dirichlet_mixture parameters')
dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0]
dir_path = get_platform().create_checkpoint_directories(
[arguments.mle_param_dirname])[0]
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y-%m-%d_%H-%M")
save_dir = (
Expand Down
21 changes: 12 additions & 9 deletions publications/mdm/mdm_paper/training/train_femnist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import argparse
import os

import joblib
import numpy as np
Expand All @@ -9,14 +9,16 @@
from pfl.internal.ops.selector import get_default_framework_module as get_ops
from pfl.internal.ops.selector import set_framework_module
from pfl.internal.platform.selector import get_platform

from publications.mdm.mdm_utils.datasets import make_femnist_datasets
from publications.mdm.mdm_utils.utils import (add_experiment_args, add_mle_args,
add_init_algorithm_args, add_algorithm_args,
add_histogram_algorithm_args,
add_user_visualisation_args)

from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle
from publications.mdm.mdm_utils.datasets import make_femnist_datasets
from publications.mdm.mdm_utils.utils import (
add_algorithm_args,
add_experiment_args,
add_histogram_algorithm_args,
add_init_algorithm_args,
add_mle_args,
add_user_visualisation_args,
)


def get_arguments():
Expand Down Expand Up @@ -56,7 +58,8 @@ def get_arguments():
print('simulated_dirichlet_mixture experiment')
if arguments.precomputed_parameter_filepath is None:
print('learn simulated_dirichlet_mixture parameters')
dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0]
dir_path = get_platform().create_checkpoint_directories(
[arguments.mle_param_dirname])[0]
save_dir = (
f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture'
)
Expand Down
23 changes: 13 additions & 10 deletions publications/mdm/mdm_paper/training/train_femnist_rebuttal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import argparse
import os

import joblib
import numpy as np
Expand All @@ -9,15 +9,17 @@
from pfl.internal.ops.selector import get_default_framework_module as get_ops
from pfl.internal.ops.selector import set_framework_module
from pfl.internal.platform.selector import get_platform

from publications.mdm.mdm_utils.datasets import make_femnist_datasets
from publications.mdm.mdm_utils.utils import (add_experiment_args, add_mle_args,
add_init_algorithm_args, add_algorithm_args,
add_histogram_algorithm_args,
add_user_visualisation_args,
add_dataset_preprocessing_args)

from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle
from publications.mdm.mdm_utils.datasets import make_femnist_datasets
from publications.mdm.mdm_utils.utils import (
add_algorithm_args,
add_dataset_preprocessing_args,
add_experiment_args,
add_histogram_algorithm_args,
add_init_algorithm_args,
add_mle_args,
add_user_visualisation_args,
)


def get_arguments():
Expand Down Expand Up @@ -63,7 +65,8 @@ def get_arguments():
print('simulated_dirichlet_mixture experiment')
if arguments.precomputed_parameter_filepath is None:
print('learn simulated_dirichlet_mixture parameters')
dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0]
dir_path = get_platform().create_checkpoint_directories(
[arguments.mle_param_dirname])[0]
save_dir = (
f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture_{arguments.filter_method}_filter_method'
)
Expand Down
2 changes: 1 addition & 1 deletion publications/mdm/mdm_utils/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .mixture_dataset import get_user_counts
from .cifar10_dataset import make_cifar10_datasets
from .femnist_dataset import make_femnist_datasets
from .mixture_dataset import get_user_counts
Loading

0 comments on commit 2065d11

Please sign in to comment.