5 changes: 0 additions & 5 deletions examples/17_resnet34_imagenet_conversion_to_analog.py
@@ -41,9 +41,4 @@
 # convolutions)
 model = convert_to_analog(model, rpu_config)
 
-# Note: One can also use ``convert_to_analog_mapped`` instead to
-# convert e.g. ``Conv2d`` to ``AnalogConv2dMapped`` (using a special way to
-# unfold over multiple tiles in a more memory efficient way
-# for some analog tiles on GPU)
-
 print(model)
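Note: with ``convert_to_analog_mapped`` gone, ``convert_to_analog`` is the single entry point, and the tile-unfolding behavior the deleted comment described appears to be folded into the size-aware default tile module class (see the ``conv.py`` change below). A quick sanity check that the convolutions were actually converted — a sketch, assuming the analog convolution class is exported as ``aihwkit.nn.AnalogConv2d``:

```python
from aihwkit.nn import AnalogConv2d

# Count how many modules in the converted model are analog convolutions.
num_analog = sum(isinstance(m, AnalogConv2d) for m in model.modules())
print(f"{num_analog} Conv2d layers were converted to analog")
```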
4 changes: 2 additions & 2 deletions examples/19_analog_summary_lenet.py
@@ -10,7 +10,7 @@
 from torch import nn
 
 # Imports from aihwkit.
-from aihwkit.nn.conversion import convert_to_analog_mapped
+from aihwkit.nn.conversion import convert_to_analog
 from aihwkit.simulator.configs import SingleRPUConfig, ConstantStepDevice
 from aihwkit.utils.analog_info import analog_summary
 
@@ -36,6 +36,6 @@
     nn.LogSoftmax(dim=1),
 )
 
-analog_model = convert_to_analog_mapped(model, rpu_config=rpu_config)
+analog_model = convert_to_analog(model, rpu_config=rpu_config)
 
 analog_summary(analog_model, (1, 1, 28, 28))
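For context, a minimal sketch of how the ``rpu_config`` that these changed lines operate on is typically built, using the imports already shown in this file:

```python
from aihwkit.simulator.configs import SingleRPUConfig, ConstantStepDevice

# One analog tile per layer, with an idealized constant-step device model.
rpu_config = SingleRPUConfig(device=ConstantStepDevice())
```

``analog_summary(analog_model, (1, 1, 28, 28))`` then prints a per-layer report of the converted network (analogous to ``torchsummary``), so the example's output is unchanged apart from the conversion call.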
8 changes: 4 additions & 4 deletions notebooks/tutorial/analog_training.ipynb
@@ -191,9 +191,9 @@
    "outputs": [],
    "source": [
     "from torchvision.models import resnet18\n",
-    "from aihwkit.nn.conversion import convert_to_analog_mapped\n",
+    "from aihwkit.nn.conversion import convert_to_analog\n",
     "\n",
-    "analog_model = convert_to_analog_mapped(resnet18(), rpu_config=rpu_config)\n",
+    "analog_model = convert_to_analog(resnet18(), rpu_config=rpu_config)\n",
     "\n",
     "print(analog_model)"
    ]
@@ -575,15 +575,15 @@
     "from torchmetrics.functional import accuracy\n",
     "\n",
     "from aihwkit.optim import AnalogSGD\n",
-    "from aihwkit.nn.conversion import convert_to_analog_mapped\n",
+    "from aihwkit.nn.conversion import convert_to_analog\n",
     "\n",
     "\n",
     "class LitAnalogModel(pl.LightningModule):\n",
     "    def __init__(self, model, rpu_config, lr=0.05):\n",
     "        super().__init__()\n",
     "\n",
     "        # We simply convert the given model to analog on-the-fly\n",
-    "        self.analog_model = convert_to_analog_mapped(model, rpu_config)\n",
+    "        self.analog_model = convert_to_analog(model, rpu_config)\n",
     "        self.lr = lr\n",
     "\n",
     "    def forward(self, x):\n",
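The Lightning module above only swaps the conversion call; training still goes through the analog optimizer imported in the same cell. A minimal sketch of that wiring, assuming the usual aihwkit pattern of calling ``regroup_param_groups`` after constructing the optimizer (the ``lr`` value mirrors the snippet's default):

```python
from aihwkit.optim import AnalogSGD

optimizer = AnalogSGD(analog_model.parameters(), lr=0.05)
# Group parameters per analog tile so the analog update rules apply.
optimizer.regroup_param_groups(analog_model)
```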
4 changes: 2 additions & 2 deletions notebooks/tutorial/extending_functionality.ipynb
@@ -128,7 +128,7 @@
     "from torchmetrics.functional import accuracy\n",
     "\n",
     "from aihwkit.optim import AnalogSGD\n",
-    "from aihwkit.nn.conversion import convert_to_analog_mapped\n",
+    "from aihwkit.nn.conversion import convert_to_analog\n",
     "\n",
     "PATH_DATASET = os.path.join('data', 'DATASET')\n",
     "os.makedirs(PATH_DATASET, exist_ok=True)\n",
@@ -163,7 +163,7 @@
     "        super().__init__()\n",
     "\n",
     "        # We simply convert the given model to analog on-the-fly\n",
-    "        self.analog_model = convert_to_analog_mapped(model, rpu_config)\n",
+    "        self.analog_model = convert_to_analog(model, rpu_config)\n",
     "        self.lr = lr\n",
     "\n",
     "    def forward(self, x):\n",
2 changes: 1 addition & 1 deletion setup.cfg
@@ -10,7 +10,7 @@ add_select = D204,D215,D401,D404
 match-dir = ^(?!helpers|definitions).*
 
 [mypy]
-python_version = 3.8
+python_version = 3.10
 namespace_packages = True
 ignore_missing_imports = True
 warn_redundant_casts = True
71 changes: 0 additions & 71 deletions src/aihwkit/nn/conversion.py
@@ -214,77 +214,6 @@ def convert_to_analog(
     return module
 
 
-def convert_to_analog_mapped(
-    module: Module,
-    rpu_config: RPUConfigGeneric,
-    tile_module_class: Optional[TileModule] = None,
-    specific_rpu_config_fun: Optional[Callable] = None,
-    module_name: str = "",
-    ensure_analog_root: bool = True,
-    exclude_modules: Optional[List[str]] = None,
-    inplace: bool = False,
-    verbose: bool = False,
-) -> Module:
-    """Convert a given digital model to its analog counterpart with tile
-    mapping support.
-
-    Note:
-        The torch device (cuda/cpu) is inferred from the original
-        models parameters, however, if multiple torch
-        devices are used in a given module, the corresponding analog
-        module is not moved to any device.
-
-    Args:
-        module: The torch module to convert. All layers that are
-            defined in the ``conversion_map``.
-        rpu_config: RPU config to apply to all converted tiles.
-        tile_module_class: Custom tile module class
-        specific_rpu_config_fun: Function that modifies the generic
-            RPUConfig for specific modules. See
-            :func:`~specific_rpu_config_id` as an example how to
-            specify it.
-
-        module_name: Explicitly given name of the base (root) module,
-            given to ``specific_rpu_config_fun``.
-
-        ensure_analog_root: Whether to ensure that the root module is
-            of layer type `AnalogLayerBase` so that custom analog are
-            methods such as `drift_analog_weigths` are available. If
-            set, it will wrap the model if `AnalogWrapper` if necessary.
-
-            Note:
-
-                Since the module structure changes when wrapped, the
-                checkpoint names will also change if this is
-                enabled (for legacy load this might need to be disabled).
-
-        exclude_modules: List of modules names that are in the
-            conversion map but should be excluded from the conversion
-
-        inplace: Whether to for in place conversion (without deepcopy)
-
-        verbose: Increase verbosity. Will print converted layers.
-
-
-    Returns:
-        Module where all the digital layers are replaced with analog
-        mapped layers.
-
-    """
-    return convert_to_analog(
-        module,
-        rpu_config,
-        tile_module_class,
-        _DEFAULT_MAPPED_CONVERSION_MAP,
-        specific_rpu_config_fun,
-        module_name,
-        ensure_analog_root,
-        exclude_modules,
-        inplace,
-        verbose,
-    )
-
-
 def convert_to_digital(
     module: Module,
     conversion_set: Optional[Set] = None,
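Migration note for downstream users: call sites only need to drop the ``_mapped`` suffix, since the mapped conversion map (``_DEFAULT_MAPPED_CONVERSION_MAP``) is no longer selectable through a separate entry point. A sketch of the equivalent call, assuming that layer splitting is now driven by the ``mapping`` limits of the RPU config (the size values below are illustrative, not taken from this PR):

```python
from torchvision.models import resnet34

from aihwkit.nn.conversion import convert_to_analog
from aihwkit.simulator.configs import InferenceRPUConfig

rpu_config = InferenceRPUConfig()
# Layers larger than the physical tile get split across multiple tiles.
rpu_config.mapping.max_input_size = 512
rpu_config.mapping.max_output_size = 512

# Before this PR: convert_to_analog_mapped(resnet34(), rpu_config)
analog_model = convert_to_analog(resnet34(), rpu_config)
```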
8 changes: 6 additions & 2 deletions src/aihwkit/nn/modules/conv.py
@@ -73,10 +73,14 @@ def __init__(
 
         rpu_config = SingleRPUConfig()
 
-        if tile_module_class is None:
-            tile_module_class = rpu_config.get_default_tile_module_class()
         self.in_features = self.get_tile_size(in_channels, groups, kernel_size)
         self.out_features = out_channels
 
+        if tile_module_class is None:
+            tile_module_class = rpu_config.get_default_tile_module_class(
+                out_size=self.out_features, in_size=self.in_features
+            )
+
        self.analog_module = tile_module_class(
            self.out_features, self.in_features, rpu_config, bias
        )
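The new ``out_size``/``in_size`` arguments let the default tile module class depend on the layer dimensions, which is presumably how the former mapped behavior is folded into the regular conversion. A sketch of the effect — the assumption (not confirmed by this diff) being that a tile-array variant is returned once a dimension exceeds the mapping limit:

```python
from aihwkit.simulator.configs import InferenceRPUConfig

rpu_config = InferenceRPUConfig()
rpu_config.mapping.max_input_size = 256

# Fits in one tile vs. exceeds the input-size limit.
small = rpu_config.get_default_tile_module_class(out_size=64, in_size=128)
large = rpu_config.get_default_tile_module_class(out_size=64, in_size=1024)
print(small, large)  # `large` may be a multi-tile (array) class
```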
2 changes: 1 addition & 1 deletion tests/test_calibration.py
@@ -56,7 +56,7 @@ def create_analog_network(rpu_config):
 
 
 def get_rpu(
-    rpu: Union[TorchInferenceRPUConfig, InferenceRPUConfig, QuantizedTorchInferenceRPUConfig],
+    rpu: Union[TorchInferenceRPUConfig, InferenceRPUConfig, QuantizedTorchInferenceRPUConfig]
 ):
     """Create test rpu config."""
     rpu.forward.out_noise = 0.01
6 changes: 3 additions & 3 deletions tests/test_quantized_tile.py
@@ -25,7 +25,7 @@ def test_output_quantization(n_bits, symmetric, range_estimator):
     """Test that output quantization works, returning the appropriate number of states"""
 
     def set_perfect_rpuconfig(
-        rpu_config: Union[TorchInferenceRPUConfig, QuantizedTorchInferenceRPUConfig],
+        rpu_config: Union[TorchInferenceRPUConfig, QuantizedTorchInferenceRPUConfig]
     ):
         rpu_config.forward.is_perfect = True
         if isinstance(rpu_config, QuantizedTorchInferenceRPUConfig):
@@ -70,7 +70,7 @@ def test_array_module_output_quantization(
     """Test that when an array is used, output quantization is properly applied"""
 
     def set_perfect_rpuconfig(
-        rpu_config: Union[TorchInferenceRPUConfig, QuantizedTorchInferenceRPUConfig],
+        rpu_config: Union[TorchInferenceRPUConfig, QuantizedTorchInferenceRPUConfig]
     ):
         rpu_config.forward.is_perfect = True
         if isinstance(rpu_config, QuantizedTorchInferenceRPUConfig):
@@ -107,7 +107,7 @@ def test_quantized_periphery(n_bits, symmetric, arr_rows, arr_columns):
     """Test that quantized periphery is properly applied"""
 
     def set_perfect_rpuconfig_with_periphery(
-        rpu_config: Union[TorchInferenceRPUConfig, QuantizedTorchInferenceRPUConfig],
+        rpu_config: Union[TorchInferenceRPUConfig, QuantizedTorchInferenceRPUConfig]
     ):
         rpu_config.forward.is_perfect = True
         rpu_config.mapping.weight_scaling_omega = 1.0
2 changes: 1 addition & 1 deletion tests/test_torch_tiles.py
@@ -459,7 +459,7 @@ def test_noise_and_bound_management(
     """
 
     def set_bm_nm(
-        rpu: Union[TorchInferenceRPUConfig, InferenceRPUConfig],
+        rpu: Union[TorchInferenceRPUConfig, InferenceRPUConfig]
     ) -> Union[TorchInferenceRPUConfig, InferenceRPUConfig]:
         """Set the rpu config."""
         rpu.forward.out_noise = 0.0