26 changes: 26 additions & 0 deletions src/aihwkit/nn/modules/base.py
@@ -7,10 +7,12 @@
"""Base class for adding functionality to analog layers."""
from typing import Any, List, Optional, Tuple, NamedTuple, Union, Generator, Callable, TYPE_CHECKING
from collections import OrderedDict
import warnings

from torch import Tensor
from torch.nn import Parameter
from torch import device as torch_device
from torch import Size as torch_size

from aihwkit.exceptions import ModuleError
from aihwkit.simulator.tiles.module import TileModule
@@ -244,6 +246,30 @@ class type when setting ``load_rpu_config`` to
``load_rpu_config=False``.

"""
keys_to_delete = []
for name, param in list(state_dict.items()):
if name.endswith("analog_ctx") and param.size() == torch_size([]):
keys_to_delete.append(name)

# For checkpoints saved by aihwkit before version 0.9.2, parameters of
# class `AnalogContext` are saved as empty-size tensors, since `AnalogContext`,
# which is a subclass of `torch.nn.Parameter`,
# used an empty Parameter tensor to store the context.
if len(keys_to_delete) > 0:
strict = False
for key in keys_to_delete:
del state_dict[key]
warnings.warn(
"Some parameters in the loaded checkpoint have empty size "
"(param.size() == torch.Size([])). "
"This can happen when the loaded checkpoint was generated "
"by an older version of aihwkit. "
"These parameters are skipped for compatibility reasons and "
"the loading mode is set to non-strict. "
"It is recommended to re-save the checkpoint with the latest "
"version of aihwkit. Related parameters are: {}".format(keys_to_delete)
)

for analog_tile in self.analog_tiles():
analog_tile.set_load_rpu_config_state(load_rpu_config, strict_rpu_config_check)
return super().load_state_dict(state_dict, strict) # type: ignore
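The compatibility path above only skips the stale entries; re-saving through the current version rewrites them. A minimal sketch of that round trip, assuming a hypothetical `AnalogLinear(4, 3)` model and checkpoint file names:

```python
# Hypothetical sketch: re-save an old (< 0.9.2) checkpoint so the analog_ctx
# entries are written in the current format and the warning above stops firing.
import torch
from aihwkit.nn import AnalogLinear

model = AnalogLinear(4, 3)
# weights_only=False may be needed on recent torch for aihwkit's pickled tile state
old_state = torch.load("old_checkpoint.pt", weights_only=False)
model.load_state_dict(old_state)              # empty-size analog_ctx entries are skipped
torch.save(model.state_dict(), "new_checkpoint.pt")
```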
4 changes: 4 additions & 0 deletions src/aihwkit/nn/modules/linear.py
@@ -82,6 +82,10 @@ def reset_parameters(self) -> None:
self.weight, self.bias = self.get_weights() # type: ignore
super().reset_parameters()
self.set_weights(self.weight, self.bias) # type: ignore
# AnalogLinear does not support accessing `weight` and `bias` directly, so delete them.
del self.weight, self.bias
# Deleting them manually is necessary since assigning `bias` (a bool) is forbidden
# by torch if self.bias is already registered as a Parameter tensor.
self.weight, self.bias = None, bias

def forward(self, x_input: Tensor) -> Tensor:
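The `del` in `reset_parameters` above works around a torch restriction; a small generic illustration (plain `torch.nn.Linear`, not aihwkit-specific) of why the manual delete is needed:

```python
# Illustration only: torch.nn.Module.__setattr__ refuses to overwrite a
# registered Parameter with a non-Parameter value, so the Parameter has to
# be deleted before a plain attribute can be assigned under the same name.
import torch
from torch import nn

layer = nn.Linear(4, 3)
try:
    layer.bias = True      # TypeError: cannot assign 'bool' as parameter 'bias'
except TypeError as err:
    print(err)

del layer.bias             # drop the registered Parameter first ...
layer.bias = True          # ... then a plain assignment is accepted
```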
37 changes: 33 additions & 4 deletions src/aihwkit/optim/context.py
@@ -10,7 +10,7 @@

from typing import Optional, Type, Union, Any, TYPE_CHECKING

from torch import ones, dtype, Tensor, no_grad
from torch import dtype, Tensor, no_grad
from torch.nn import Parameter
from torch import device as torch_device

@@ -19,7 +19,30 @@


class AnalogContext(Parameter):
"""Context for analog optimizer."""
"""Context for analog optimizer.

Note: `data` attribution, inherited from `torch.nn.Parameter`, is a tensor of training parameter
If `analog_bias` (which is provided by `analog_tile`) is False,
`data` has the same meaning as `torch.nn.Parameter`
If `analog_bias` (which is provided by `analog_tile`) is True,
The last column of `data` is the `bias` term

Even though it allows us to access the weights directly, always keep in mind that it is used
only for studying propuses. To simulate the real reading, call the `read_weights` method
instead, i.e. given `analog_ctx: AnalogContext`,
estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights()

Similarly, even though this feature allows us to update the weights directly,
always keep in mind that the real RPU devices change their weights only
by "pulse update" method.

Therefore, use the following update methods instead of
writing `data` directly in the analog optimizer:
---
analog_ctx.analog_tile.update(...)
analog_ctx.analog_tile.update_indexed(...)
---
"""

def __new__(
cls: Type["AnalogContext"],
@@ -30,9 +53,15 @@ def __new__(
if parameter is None:
return Parameter.__new__(
cls,
data=ones((), device=analog_tile.device, dtype=analog_tile.get_dtype()),
data=analog_tile.tile.get_weights(),
requires_grad=True,
)
# analog_tile.tile can come from different classes:
#   aihwkit.simulator.rpu_base.devices.AnalogTile (C++)
#   TorchInferenceTile (Python)
# It stores the "weight" matrix;
# if analog_tile.analog_bias is True, it also stores the "bias" column.

parameter.__class__ = cls
return parameter

@@ -92,8 +121,8 @@ def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "Analo
Returns:
This context in the specified device.
"""
self.data = self.data.cuda(device) # type: Tensor
if not self.analog_tile.is_cuda:
self.data = self.analog_tile.tile.get_weights() # type: Tensor
self.analog_tile = self.analog_tile.cuda(device)
self.reset(self.analog_tile)
return self
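Following the new `AnalogContext` docstring, the raw `data` tensor is only a convenience view; a short sketch (hypothetical layer size, default configuration assumed) contrasting it with the recommended `read_weights` call:

```python
# Sketch under assumed defaults: analog_ctx.data exposes the stored weights
# directly (study/debugging only), while read_weights() performs a simulated
# read-out as the docstring recommends.
from aihwkit.nn import AnalogLinear

layer = AnalogLinear(4, 3)
tile = next(layer.analog_tiles())

raw = tile.analog_ctx.data            # direct view; last column is the bias if analog_bias
weights, bias = tile.read_weights()   # simulated ("measured") weights and bias
```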
4 changes: 2 additions & 2 deletions src/aihwkit/simulator/tiles/array.py
@@ -115,7 +115,7 @@ def set_weights(self, weight: Tensor, bias: Optional[Tensor] = None, **kwargs: A
in_start = in_end

if self.bias is not None and bias is not None:
self.bias.data = bias.detach().to(self.bias.device)
self.bias.data.copy_(bias)

@no_grad()
def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]:
@@ -140,7 +140,7 @@ def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]:
weight = cat(weight_lst, 1)

if self.bias is not None:
return weight, self.bias.clone().cpu()
return weight, self.bias
return weight, None

def forward(self, x_input: Tensor, tensor_view: Optional[Tuple] = None) -> Tensor:
56 changes: 43 additions & 13 deletions src/aihwkit/simulator/tiles/base.py
@@ -264,6 +264,7 @@ def get_meta_parameters(self) -> Any:
raise NotImplementedError


# pylint: disable=too-many-public-methods
class SimulatorTileWrapper:
"""Wrapper base class for defining the necessary tile
functionality.
@@ -281,6 +282,18 @@ class SimulatorTileWrapper:
should be used.
handle_output_bound: whether the bound clamp gradient should be inserted
ignore_analog_state: whether to ignore the analog state when __getstate__ is called

Attributes:
tile: A simulator tile object that handles the computations
for the given input/output sizes.
It is created by the `self._create_simulator_tile` method,
which is provided by the derived class.
E.g., `aihwkit.simulator.tiles.analog.AnalogTile` and
`aihwkit.simulator.tiles.inference_torch.TorchInferenceTile`
implement this method.
The weight data is stored in the tile object.
analog_ctx: `AnalogContext`, which wraps the weight stored in the tile
into a `torch.nn.Parameter` object.
"""

def __init__(
@@ -295,8 +308,6 @@ def __init__(
handle_output_bound: bool = False,
ignore_analog_state: bool = False,
):
self.is_cuda = False
self.device = torch_device("cpu")
self.out_size = out_size
self.in_size = in_size
self.rpu_config = deepcopy(rpu_config)
@@ -324,6 +335,16 @@ def __init__(
self.analog_ctx = AnalogContext(self)
self.analog_ctx.use_torch_update = torch_update

@property
def device(self) -> torch_device:
"""Return the device of the tile."""
return self.analog_ctx.device

@property
def is_cuda(self) -> bool:
"""Return the is_cuda state of the tile."""
return self.analog_ctx.is_cuda

def get_runtime(self) -> RuntimeParameter:
"""Returns the runtime parameter."""
if not hasattr(self.rpu_config, "runtime"):
@@ -456,8 +477,6 @@ def __getstate__(self) -> Dict:

# don't save device. Will be determined by loading object
current_dict.pop("stream", None)
current_dict.pop("is_cuda", None)
current_dict.pop("device", None)

# this should not be saved.
current_dict.pop("image_sizes", None)
@@ -527,16 +546,14 @@ def __setstate__(self, state: Dict) -> None:
self.rpu_config = rpu_config
self.__dict__.update(current_dict)

self.device = torch_device("cpu")
self.is_cuda = False

# recreate attributes not saved
# always first create on CPU
x_size = self.in_size + 1 if self.analog_bias else self.in_size
d_size = self.out_size

# Recreate the tile.
self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config)
self.analog_ctx.data = self.tile.get_weights()
Contributor:
What is the logic here? Note that this will only be a copy of the current weights. So if you update the weights (using RPUCuda), the analog_ctx.data will not be synchronized correctly with the actual weight. Of course the size of the weight will not change, but it will be more confusing if one maintains two different versions of the weight which are not synced, or not?

Author:
Absolutely, there could be an out-of-sync concern here. Therefore, I also changed the definition of self.tile.get_weights(): the tile now returns the original weight instead of a detached tensor. Since the data here and the actual weight tensor are essentially the same object, there is no sync issue:

def get_weights(self) -> Tensor:
    """Get the tile weights.
        matrix; and the second item is either the ``[out_size]`` bias vector
        or ``None`` if the tile is set not to use bias.
    """
    return self.weight.data
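A hedged illustration of the aliasing argument made in the reply above, using a plain `torch.nn.Parameter` as a stand-in for the tile weight:

```python
# If get_weights() hands back the live `weight.data` tensor, then whatever is
# stored in analog_ctx.data shares storage with the tile weight, so in-place
# updates remain visible through the context (no second, unsynced copy).
import torch
from torch.nn import Parameter

weight = Parameter(torch.zeros(3, 4))        # stand-in for the tile's weight
analog_ctx_data = weight.data                # what analog_ctx.data would hold

weight.data.add_(1.0)                        # simulated in-place weight update
print(torch.equal(analog_ctx_data, weight.data))   # True: same underlying storage
```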


names = self.tile.get_hidden_parameter_names()
if len(hidden_parameters_names) > 0 and names != hidden_parameters_names:
@@ -568,8 +585,15 @@ def __setstate__(self, state: Dict) -> None:
if analog_ctx is not None:
# Keep the object ID and device
to_device = analog_ctx.device
if self.device != to_device:
self.analog_ctx = self.analog_ctx.to(to_device)
if self.analog_ctx.device != to_device:
# aihwkit implements analog tiles in both CPU and CUDA versions,
# e.g. FloatingPointTile(RPUSimple<float>(4, 3))
# vs. FloatingPointTile(RPUCudaSimple<float>(4, 3)).
# Here we need to manually convert the tile to the corresponding version.
self.to(to_device)
# Note: `self.to(to_device)` already moves `self.analog_ctx.data` to the target
# device, so there is no need to call
# `self.analog_ctx = self.analog_ctx.to(to_device)` again.
self.analog_ctx.set_data(analog_ctx.data)

@no_grad()
@@ -632,6 +656,16 @@ def _separate_weights(self, combined_weights: Tensor) -> Tuple[Tensor, Optional[

return combined_weights, None

# pylint: disable=invalid-name
def to(self, device: torch_device) -> "SimulatorTileWrapper":
"""Move the tile to a device.
"""
if device.type == "cuda":
self.cuda(device)
else:
self.cpu()
return self

@no_grad()
def cpu(self) -> "SimulatorTileWrapper":
"""Return a copy of this tile in CPU memory.
@@ -642,8 +676,6 @@ def cpu(self) -> "SimulatorTileWrapper":
if not self.is_cuda:
return self

self.is_cuda = False
self.device = torch_device("cpu")
self.analog_ctx.data = self.analog_ctx.data.cpu()
self.analog_ctx.reset(self)

@@ -665,8 +697,6 @@ def cuda(
CudaError: if the library has not been compiled with CUDA.
"""
device = torch_device("cuda", cuda_device(device).idx)
self.is_cuda = True
self.device = device
self.analog_ctx.data = self.analog_ctx.data.cuda(device)
self.analog_ctx.reset(self)
return self
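With the stored `is_cuda`/`device` attributes removed above, both values are now derived from `analog_ctx`; a small sketch (CPU tile, hypothetical sizes) of reading them:

```python
# Sketch: `device` and `is_cuda` are properties backed by analog_ctx, so they
# cannot drift out of sync with the context tensor anymore.
from aihwkit.nn import AnalogLinear

layer = AnalogLinear(4, 3)
tile = next(layer.analog_tiles())

print(tile.is_cuda)                            # False for a freshly created CPU tile
print(tile.device)                             # cpu
print(tile.device == tile.analog_ctx.device)   # True: derived, not duplicated
```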
3 changes: 1 addition & 2 deletions src/aihwkit/simulator/tiles/custom.py
@@ -147,8 +147,7 @@ def set_weights(self, weight: Tensor) -> None:
Args:
weight: ``[out_size, in_size]`` weight matrix.
"""
device = self._analog_weight.device
self._analog_weight = weight.clone().to(device)
self._analog_weight.copy_(weight)

def get_weights(self) -> Tensor:
"""Get the tile weights.
2 changes: 2 additions & 0 deletions src/aihwkit/simulator/tiles/functions.py
@@ -32,6 +32,8 @@ def forward(
Note: Indexed versions can be used when analog_ctx.use_indexed is
set to True.
"""
# `ctx` is the parameter required by PyTorch to store the autograd context;
# there is no need to pass it through `AnalogFunction.apply(...)`.
# Store in the context for use during `backward()`.
ctx.analog_ctx = analog_ctx
ctx.analog_tile = analog_tile
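The comment about `ctx` above refers to the standard `torch.autograd.Function` mechanics; a generic, self-contained sketch (not the aihwkit `AnalogFunction` itself) of how PyTorch supplies `ctx` and how `apply` is called without it:

```python
# Generic torch.autograd.Function example: `ctx` is injected by PyTorch and
# used to stash state for backward(); callers invoke .apply() without it.
import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)          # store what backward() will need
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)                       # note: no `ctx` argument here
y.backward()
print(x.grad)                             # tensor(6.)
```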
2 changes: 1 addition & 1 deletion src/aihwkit/simulator/tiles/periphery.py
@@ -150,7 +150,7 @@ def set_weights(
if not isinstance(bias, Tensor):
bias = from_numpy(array(bias))

self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device)
self.bias.data.copy_(bias)
Contributor:
Does this allow for setting the bias with the data type defined in the tile? While it is correct that torch defines the data type of a layer solely by the data type of the weight tensor, I found it more convenient to have a dtype property on the tile level that handles all the specialized code (as this is essentially the "analog tensor"). Do you suggest that the dtype should be removed from the tile and instead be determined by the ctx.data tensor dtype?
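One practical point behind the `copy_` change under discussion: `Tensor.copy_` casts the source to the destination's dtype and device, so the tile-level dtype of `self.bias` is preserved. A quick, generic illustration:

```python
# Generic illustration (not aihwkit-specific): copy_ keeps the destination's
# dtype, so a float32 bias written into a float16 buffer is cast on the fly.
import torch

dst = torch.zeros(3, dtype=torch.float16)   # stands in for self.bias.data
src = torch.randn(3, dtype=torch.float32)   # caller-provided bias
dst.copy_(src)                              # implicit dtype (and device) cast
print(dst.dtype)                            # torch.float16
```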

bias = None

combined_weights = self._combine_weights(weight, bias)
2 changes: 0 additions & 2 deletions src/aihwkit/simulator/tiles/rpucuda.py
@@ -153,8 +153,6 @@ def cuda(
if self.tile.__class__ in MAP_TILE_CLASS_TO_CUDA:
with cuda_device(device):
self.tile = MAP_TILE_CLASS_TO_CUDA[self.tile.__class__](self.tile)
self.is_cuda = True
self.device = device
self.analog_ctx.data = self.analog_ctx.data.cuda(device)
self.analog_ctx.reset(self) # type: ignore

4 changes: 2 additions & 2 deletions src/aihwkit/simulator/tiles/torch_tile.py
@@ -77,7 +77,7 @@ def set_weights(self, weight: Tensor) -> None:
Args:
weight: ``[out_size, in_size]`` weight matrix.
"""
self.weight.data = weight.clone().to(self.weight.device)
self.weight.data.copy_(weight)

def get_weights(self) -> Tensor:
"""Get the tile weights.
@@ -87,7 +87,7 @@ def get_weights(self) -> Tensor:
matrix; and the second item is either the ``[out_size]`` bias vector
or ``None`` if the tile is set not to use bias.
"""
return self.weight.data.detach().cpu()
return self.weight.data
Contributor:
Note that the current convention is that get_weights always returns CPU weights. If you want to change this, the RPUCuda get_weights calls need to change as well, as they produce CPU weights by default. Moreover, get_weights always produces a copy (without the backward trace) by design, to avoid implicit operations that cannot be done with analog weights. Of course, hardware-aware training is a special case, but for that we have a separate tile.


def get_x_size(self) -> int:
"""Returns input size of tile"""
4 changes: 4 additions & 0 deletions tests/test_layers_linear.py
@@ -119,8 +119,12 @@ def test_seed(self):
weight1, bias1 = layer1.get_weights()
weight2, bias2 = layer2.get_weights()

if self.use_cuda:
weight1, weight2 = weight1.cpu(), weight2.cpu()
assert_array_almost_equal(weight1, weight2)
if bias1 is not None:
if self.use_cuda:
bias1, bias2 = bias1.cpu(), bias2.cpu()
assert_array_almost_equal(bias1, bias2)

def test_several_analog_layers(self):
14 changes: 12 additions & 2 deletions tests/test_utils.py
@@ -100,21 +100,26 @@ def train_model(model, loss_func, x_b, y_b):
@staticmethod
def get_layer_and_tile_weights(model):
"""Return the weights and biases of the model and the tile and whether
it automatically syncs"""
it automatically syncs

Note: All the weights and biases are detached and converted to numpy format."""

if isinstance(model, AnalogLinearMapped):
weight, bias = model.get_weights()
weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy()
return weight, bias, weight, bias, True

if isinstance(model, AnalogConv2dMapped):
weight, bias = model.get_weights()
weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy()
return weight, bias, weight, bias, True

if model.weight is not None:
weight = model.weight.data.detach().cpu().numpy()
else:
# we do not sync anymore
weight, bias = model.get_weights()
weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy()
return weight, bias, weight, bias, True

if model.bias is not None:
@@ -436,6 +441,7 @@ def test_save_load_model_cross_device(self):
self.assertIsInstance(new_analog_tile.analog_ctx.analog_tile, analog_tile.__class__)

self.assertTrue(new_analog_tile.is_cuda != analog_tile.is_cuda)
self.assertTrue(new_analog_tile.device.type == map_location)

if analog_tile.shared_weights is not None:
self.assertTrue(new_analog_tile.shared_weights.device.type == map_location)
@@ -923,7 +929,11 @@ def test_load_state_dict_conversion(self):

state1 = new_state_dict[key]
state2 = state_dict[key]
assert_array_almost_equal(state1["analog_tile_weights"], state2["analog_tile_weights"])
weights1 = state1["analog_tile_weights"]
weights2 = state2["analog_tile_weights"]
if self.use_cuda:
weights1, weights2 = weights1.cpu(), weights2.cpu()
assert_array_almost_equal(weights1, weights2)
# assert_array_almost_equal(state1['analog_alpha_scale'],
# state2['analog_alpha_scale'])
