diff --git a/src/aihwkit/nn/modules/base.py b/src/aihwkit/nn/modules/base.py
index 1424061d..3f346b34 100644
--- a/src/aihwkit/nn/modules/base.py
+++ b/src/aihwkit/nn/modules/base.py
@@ -7,10 +7,12 @@
 """Base class for adding functionality to analog layers."""
 
 from typing import Any, List, Optional, Tuple, NamedTuple, Union, Generator, Callable, TYPE_CHECKING
 from collections import OrderedDict
+import warnings
 
 from torch import Tensor
 from torch.nn import Parameter
 from torch import device as torch_device
+from torch import Size as torch_size
 
 from aihwkit.exceptions import ModuleError
 from aihwkit.simulator.tiles.module import TileModule
@@ -244,6 +246,30 @@
                 class type when setting ``load_rpu_config`` to
                 ``load_rpu_config=False``.
         """
+        keys_to_delete = []
+        for name, param in list(state_dict.items()):
+            if name.endswith("analog_ctx") and param.size() == torch_size([]):
+                keys_to_delete.append(name)
+
+        # For checkpoints saved by aihwkit before version 0.9.2, parameters of
+        # class `AnalogContext` are saved as empty-sized tensors, since `AnalogContext`,
+        # which is a class derived from `torch.nn.Parameter`,
+        # uses an empty Parameter tensor to store the context.
+        if len(keys_to_delete) > 0:
+            strict = False
+            for key in keys_to_delete:
+                del state_dict[key]
+            warnings.warn(
+                "Some parameters in the loaded checkpoint have empty size "
+                "(param.size() == torch.Size([])). "
+                "This can happen when the loaded checkpoint "
+                "was generated by an older version of aihwkit. "
+                "These parameters are skipped for compatibility reasons "
+                "and the loading mode is set to non-strict. "
+                "It is recommended to re-save the checkpoint with the latest version of aihwkit. "
+                "Related parameters are: {}".format(keys_to_delete)
+            )
+
         for analog_tile in self.analog_tiles():
             analog_tile.set_load_rpu_config_state(load_rpu_config, strict_rpu_config_check)
         return super().load_state_dict(state_dict, strict)  # type: ignore
diff --git a/src/aihwkit/nn/modules/linear.py b/src/aihwkit/nn/modules/linear.py
index 9db3bee2..8a5debfe 100644
--- a/src/aihwkit/nn/modules/linear.py
+++ b/src/aihwkit/nn/modules/linear.py
@@ -82,6 +82,10 @@ def reset_parameters(self) -> None:
         self.weight, self.bias = self.get_weights()  # type: ignore
         super().reset_parameters()
         self.set_weights(self.weight, self.bias)  # type: ignore
+        # AnalogLinear doesn't support accessing weight and bias directly, so delete them.
+        del self.weight, self.bias
+        # Deleting them manually is necessary since assigning `bias` (a bool) is forbidden
+        # by torch if self.bias is already a tensor.
         self.weight, self.bias = None, bias
 
     def forward(self, x_input: Tensor) -> Tensor:
diff --git a/src/aihwkit/optim/context.py b/src/aihwkit/optim/context.py
index d9228d03..b5c5c33d 100644
--- a/src/aihwkit/optim/context.py
+++ b/src/aihwkit/optim/context.py
@@ -10,7 +10,7 @@
 
 from typing import Optional, Type, Union, Any, TYPE_CHECKING
 
-from torch import ones, dtype, Tensor, no_grad
+from torch import dtype, Tensor, no_grad
 from torch.nn import Parameter
 from torch import device as torch_device
 
@@ -19,7 +19,30 @@
 
 
 class AnalogContext(Parameter):
-    """Context for analog optimizer."""
+    """Context for analog optimizer.
+
+    Note: the `data` attribute, inherited from `torch.nn.Parameter`, is a tensor of trainable parameters.
+    If `analog_bias` (which is provided by `analog_tile`) is False,
+    `data` has the same meaning as in `torch.nn.Parameter`.
+    If `analog_bias` (which is provided by `analog_tile`) is True,
+    the last column of `data` is the `bias` term.
+
+    Even though this allows us to access the weights directly, keep in mind that it is meant
+    only for study purposes. To simulate a realistic read, call the `read_weights` method
+    instead, i.e. given `analog_ctx: AnalogContext`,
+    estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights()
+
+    Similarly, even though this feature allows us to update the weights directly,
+    keep in mind that real RPU devices change their weights only
+    by the "pulse update" method.
+
+    Therefore, use the following update methods instead of
+    writing to `data` directly in the analog optimizer:
+    ---
+    analog_ctx.analog_tile.update(...)
+    analog_ctx.analog_tile.update_indexed(...)
+    ---
+    """
 
     def __new__(
         cls: Type["AnalogContext"],
@@ -30,9 +53,15 @@ def __new__(
         if parameter is None:
             return Parameter.__new__(
                 cls,
-                data=ones((), device=analog_tile.device, dtype=analog_tile.get_dtype()),
+                data=analog_tile.tile.get_weights(),
                 requires_grad=True,
             )
+        # analog_tile.tile can come from different classes, e.g.:
+        #   aihwkit.simulator.rpu_base.devices.AnalogTile (C++)
+        #   TorchInferenceTile (Python)
+        # It stores the "weight" matrix;
+        # if analog_tile.analog_bias is True, it also stores the "bias" values.
+
         parameter.__class__ = cls
         return parameter
@@ -92,8 +121,8 @@ def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "Analo
         Returns:
             This context in the specified device.
         """
-        self.data = self.data.cuda(device)  # type: Tensor
         if not self.analog_tile.is_cuda:
+            self.data = self.analog_tile.tile.get_weights()  # type: Tensor
             self.analog_tile = self.analog_tile.cuda(device)
         self.reset(self.analog_tile)
         return self
diff --git a/src/aihwkit/simulator/tiles/array.py b/src/aihwkit/simulator/tiles/array.py
index 43fdbc17..98a0f0f7 100644
--- a/src/aihwkit/simulator/tiles/array.py
+++ b/src/aihwkit/simulator/tiles/array.py
@@ -115,7 +115,7 @@ def set_weights(self, weight: Tensor, bias: Optional[Tensor] = None, **kwargs: A
             in_start = in_end
 
         if self.bias is not None and bias is not None:
-            self.bias.data = bias.detach().to(self.bias.device)
+            self.bias.data.copy_(bias)
 
     @no_grad()
     def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]:
@@ -140,7 +140,7 @@ def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]:
         weight = cat(weight_lst, 1)
 
         if self.bias is not None:
-            return weight, self.bias.clone().cpu()
+            return weight, self.bias
         return weight, None
 
     def forward(self, x_input: Tensor, tensor_view: Optional[Tuple] = None) -> Tensor:
diff --git a/src/aihwkit/simulator/tiles/base.py b/src/aihwkit/simulator/tiles/base.py
index d9383d51..ba026acd 100644
--- a/src/aihwkit/simulator/tiles/base.py
+++ b/src/aihwkit/simulator/tiles/base.py
@@ -264,6 +264,7 @@ def get_meta_parameters(self) -> Any:
         raise NotImplementedError
 
 
+# pylint: disable=too-many-public-methods
 class SimulatorTileWrapper:
     """Wrapper base class for defining the necessary tile functionality.
 
@@ -281,6 +282,18 @@ class SimulatorTileWrapper:
            should be used.
        handle_output_bound: whether the bound clamp gradient should be inserted
        ignore_analog_state: whether to ignore the analog state when __getstate__ is called
+
+    Attributes:
+        tile: A simulator tile object that handles the computations
+            for the given input/output sizes.
+            It is created by the `self._create_simulator_tile` method,
+            which is provided by the derived class.
+            E.g., `aihwkit.simulator.tiles.analog.AnalogTile` and
+            `aihwkit.simulator.tiles.inference_torch.TorchInferenceTile`
+            implement this method.
+            The weight data is stored in the tile object.
+        analog_ctx: `AnalogContext`, which wraps the tile weights
+            into a `torch.nn.Parameter` object.
    """
 
    def __init__(
@@ -295,8 +308,6 @@ def __init__(
        handle_output_bound: bool = False,
        ignore_analog_state: bool = False,
    ):
-        self.is_cuda = False
-        self.device = torch_device("cpu")
        self.out_size = out_size
        self.in_size = in_size
        self.rpu_config = deepcopy(rpu_config)
@@ -324,6 +335,16 @@ def __init__(
        self.analog_ctx = AnalogContext(self)
        self.analog_ctx.use_torch_update = torch_update
 
+    @property
+    def device(self) -> torch_device:
+        """Return the device of the tile."""
+        return self.analog_ctx.device
+
+    @property
+    def is_cuda(self) -> bool:
+        """Return the is_cuda state of the tile."""
+        return self.analog_ctx.is_cuda
+
    def get_runtime(self) -> RuntimeParameter:
        """Returns the runtime parameter."""
        if not hasattr(self.rpu_config, "runtime"):
@@ -456,8 +477,6 @@ def __getstate__(self) -> Dict:
 
        # don't save device. Will be determined by loading object
        current_dict.pop("stream", None)
-        current_dict.pop("is_cuda", None)
-        current_dict.pop("device", None)
 
        # this is should not be saved.
        current_dict.pop("image_sizes", None)
@@ -527,9 +546,6 @@ def __setstate__(self, state: Dict) -> None:
        self.rpu_config = rpu_config
 
        self.__dict__.update(current_dict)
-        self.device = torch_device("cpu")
-        self.is_cuda = False
-
        # recreate attributes not saved
        # always first create on CPU
        x_size = self.in_size + 1 if self.analog_bias else self.in_size
@@ -537,6 +553,7 @@ def __setstate__(self, state: Dict) -> None:
 
        # Recreate the tile.
        self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config)
+        self.analog_ctx.data = self.tile.get_weights()
 
        names = self.tile.get_hidden_parameter_names()
        if len(hidden_parameters_names) > 0 and names != hidden_parameters_names:
@@ -568,8 +585,15 @@ def __setstate__(self, state: Dict) -> None:
        if analog_ctx is not None:
            # Keep the object ID and device
            to_device = analog_ctx.device
-            if self.device != to_device:
-                self.analog_ctx = self.analog_ctx.to(to_device)
+            if self.analog_ctx.device != to_device:
+                # aihwkit implements analog tiles in both CPU and CUDA versions,
+                # e.g. FloatingPointTile(RPUSimple(4, 3))
+                # vs. FloatingPointTile(RPUCudaSimple(4, 3)).
+                # Here we need to manually convert the tile to the corresponding version.
+                self.to(to_device)
+                # Note: `self.to(to_device)` already moves `self.analog_ctx.data`,
+                # so there is no need to call
+                # `self.analog_ctx = self.analog_ctx.to(to_device)` again.
            self.analog_ctx.set_data(analog_ctx.data)
 
    @no_grad()
@@ -632,6 +656,16 @@ def _separate_weights(self, combined_weights: Tensor) -> Tuple[Tensor, Optional[
 
        return combined_weights, None
 
+    # pylint: disable=invalid-name
+    def to(self, device: torch_device) -> "SimulatorTileWrapper":
+        """Move the tile to the given device (CPU or CUDA).
+        """
+        if device.type == "cuda":
+            self.cuda(device)
+        else:
+            self.cpu()
+        return self
+
    @no_grad()
    def cpu(self) -> "SimulatorTileWrapper":
        """Return a copy of this tile in CPU memory.
@@ -642,8 +676,6 @@ def cpu(self) -> "SimulatorTileWrapper":
        if not self.is_cuda:
            return self
 
-        self.is_cuda = False
-        self.device = torch_device("cpu")
        self.analog_ctx.data = self.analog_ctx.data.cpu()
        self.analog_ctx.reset(self)
 
@@ -665,8 +697,6 @@ def cuda(
            CudaError: if the library has not been compiled with CUDA.
        """
        device = torch_device("cuda", cuda_device(device).idx)
-        self.is_cuda = True
-        self.device = device
        self.analog_ctx.data = self.analog_ctx.data.cuda(device)
        self.analog_ctx.reset(self)
        return self
diff --git a/src/aihwkit/simulator/tiles/custom.py b/src/aihwkit/simulator/tiles/custom.py
index de1d4425..99e2ed68 100644
--- a/src/aihwkit/simulator/tiles/custom.py
+++ b/src/aihwkit/simulator/tiles/custom.py
@@ -147,8 +147,7 @@ def set_weights(self, weight: Tensor) -> None:
        Args:
            weight: ``[out_size, in_size]`` weight matrix.
        """
-        device = self._analog_weight.device
-        self._analog_weight = weight.clone().to(device)
+        self._analog_weight.copy_(weight)
 
    def get_weights(self) -> Tensor:
        """Get the tile weights.
diff --git a/src/aihwkit/simulator/tiles/functions.py b/src/aihwkit/simulator/tiles/functions.py
index e7ead29f..84938a83 100644
--- a/src/aihwkit/simulator/tiles/functions.py
+++ b/src/aihwkit/simulator/tiles/functions.py
@@ -32,6 +32,8 @@ def forward(
 
        Note: Indexed versions can used when analog_ctx.use_indexed is set to True.
        """
+        # `ctx` is the parameter required by PyTorch to store the context;
+        # there is no need to pass it through `AnalogFunction.apply(...)`.
        # Store in context for using during `backward()`.
        ctx.analog_ctx = analog_ctx
        ctx.analog_tile = analog_tile
diff --git a/src/aihwkit/simulator/tiles/periphery.py b/src/aihwkit/simulator/tiles/periphery.py
index a9f2d950..e87b41c5 100644
--- a/src/aihwkit/simulator/tiles/periphery.py
+++ b/src/aihwkit/simulator/tiles/periphery.py
@@ -150,7 +150,7 @@ def set_weights(
 
            if not isinstance(bias, Tensor):
                bias = from_numpy(array(bias))
-            self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device)
+            self.bias.data.copy_(bias)
            bias = None
 
        combined_weights = self._combine_weights(weight, bias)
diff --git a/src/aihwkit/simulator/tiles/rpucuda.py b/src/aihwkit/simulator/tiles/rpucuda.py
index e549d438..1ba4fafc 100644
--- a/src/aihwkit/simulator/tiles/rpucuda.py
+++ b/src/aihwkit/simulator/tiles/rpucuda.py
@@ -153,8 +153,6 @@ def cuda(
        if self.tile.__class__ in MAP_TILE_CLASS_TO_CUDA:
            with cuda_device(device):
                self.tile = MAP_TILE_CLASS_TO_CUDA[self.tile.__class__](self.tile)
-        self.is_cuda = True
-        self.device = device
 
        self.analog_ctx.data = self.analog_ctx.data.cuda(device)
        self.analog_ctx.reset(self)  # type: ignore
diff --git a/src/aihwkit/simulator/tiles/torch_tile.py b/src/aihwkit/simulator/tiles/torch_tile.py
index 6c8dbd66..08c8754c 100644
--- a/src/aihwkit/simulator/tiles/torch_tile.py
+++ b/src/aihwkit/simulator/tiles/torch_tile.py
@@ -77,7 +77,7 @@ def set_weights(self, weight: Tensor) -> None:
        Args:
            weight: ``[out_size, in_size]`` weight matrix.
        """
-        self.weight.data = weight.clone().to(self.weight.device)
+        self.weight.data.copy_(weight)
 
    def get_weights(self) -> Tensor:
        """Get the tile weights.
@@ -87,7 +87,7 @@ def get_weights(self) -> Tensor:
            matrix; and the second item is either the ``[out_size]`` bias
            vector or ``None`` if the tile is set not to use bias.
""" - return self.weight.data.detach().cpu() + return self.weight.data def get_x_size(self) -> int: """Returns input size of tile""" diff --git a/tests/test_layers_linear.py b/tests/test_layers_linear.py index 05d3e60d..cfca2f12 100644 --- a/tests/test_layers_linear.py +++ b/tests/test_layers_linear.py @@ -119,8 +119,12 @@ def test_seed(self): weight1, bias1 = layer1.get_weights() weight2, bias2 = layer2.get_weights() + if self.use_cuda: + weight1, weight2 = weight1.cpu(), weight2.cpu() assert_array_almost_equal(weight1, weight2) if bias1 is not None: + if self.use_cuda: + bias1, bias2 = bias1.cpu(), bias2.cpu() assert_array_almost_equal(bias1, bias2) def test_several_analog_layers(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index 8932c478..0d84bfae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -100,14 +100,18 @@ def train_model(model, loss_func, x_b, y_b): @staticmethod def get_layer_and_tile_weights(model): """Return the weights and biases of the model and the tile and whether - it automatically syncs""" + it automatically syncs + + Note: All the weights and biases are detached and converted to numpy format.""" if isinstance(model, AnalogLinearMapped): weight, bias = model.get_weights() + weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy() return weight, bias, weight, bias, True if isinstance(model, AnalogConv2dMapped): weight, bias = model.get_weights() + weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy() return weight, bias, weight, bias, True if model.weight is not None: @@ -115,6 +119,7 @@ def get_layer_and_tile_weights(model): else: # we do not sync anymore weight, bias = model.get_weights() + weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy() return weight, bias, weight, bias, True if model.bias is not None: @@ -436,6 +441,7 @@ def test_save_load_model_cross_device(self): self.assertIsInstance(new_analog_tile.analog_ctx.analog_tile, analog_tile.__class__) self.assertTrue(new_analog_tile.is_cuda != analog_tile.is_cuda) + self.assertTrue(new_analog_tile.device.type == map_location) if analog_tile.shared_weights is not None: self.assertTrue(new_analog_tile.shared_weights.device.type == map_location) @@ -923,7 +929,11 @@ def test_load_state_dict_conversion(self): state1 = new_state_dict[key] state2 = state_dict[key] - assert_array_almost_equal(state1["analog_tile_weights"], state2["analog_tile_weights"]) + weights1 = state1["analog_tile_weights"] + weights2 = state2["analog_tile_weights"] + if self.use_cuda: + weights1, weights2 = weights1.cpu(), weights2.cpu() + assert_array_almost_equal(weights1, weights2) # assert_array_almost_equal(state1['analog_alpha_scale'], # state2['analog_alpha_scale'])