From dbba92667a6b286607d708b83c10a3ba4bfc225e Mon Sep 17 00:00:00 2001 From: "Jose Luquin jluquin@ibm.com" Date: Wed, 12 Nov 2025 17:36:32 -0500 Subject: [PATCH 1/4] Modifications/Cleanup/Improvements to IRDrop example and implementation Signed-off-by: Jose Luquin jluquin@ibm.com --- examples/28_advanced_irdrop.py | 74 +++++-- src/aihwkit/simulator/parameters/io.py | 48 ++++- .../simulator/tiles/analog_mvm_irdrop_t.py | 184 +++++++++++++----- 3 files changed, 239 insertions(+), 67 deletions(-) diff --git a/examples/28_advanced_irdrop.py b/examples/28_advanced_irdrop.py index 39924786..c3e5e625 100644 --- a/examples/28_advanced_irdrop.py +++ b/examples/28_advanced_irdrop.py @@ -2,7 +2,13 @@ # (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. # -# Licensed under the MIT license. See LICENSE file in the project root for details. +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. 
"""aihwkit example 28: advanced (time-dependent) IR drop effects @@ -83,7 +89,7 @@ def rpu_config_modifications(rpu_config: Type[RPUConfigBase]) -> Type[RPUConfigB rpu_config.forward.ir_drop_v_read = 0.4 rpu_config.forward.ir_drop = 1.0 rpu_config.forward.nm_thres = 1.0 - rpu_config.forward.inp_res = 2**10 + rpu_config.forward.inp_res = 2**10 - 2 rpu_config.forward.out_bound = -1 # 10 - quite restrictive rpu_config.forward.out_res = -1 rpu_config.forward.out_noise = 0.0 # prevent bit-wise mode from amplifying noise too much @@ -107,9 +113,16 @@ def network_comparison( Returns: None """ plot_model_names_dict = { - "conventional_model_dt_irdrop": r"Conventional \ Mode \ Advanced \ IR \ Drop", - "split_mode_pwm_dt_irdrop": r"Split \ Mode \ Advanced \ IR \ Drop", - "bit_wise_dt_irdrop": r"Bit \ Wise \ Mode \ Advanced \ IR \ Drop", + "default_conventional_model_dt_irdrop": + r"Conventional \ Mode \ Adv.\ IR \ Drop \ Default", + "conventional_model_dt_irdrop_PCMnoise": + r"Conventional \ Mode \ Adv. \ IR \ Drop \ PCM \ Noise", + "conventional_model_dt_irdrop_noADC": + r"Conventional \ Mode \ Advanced \ IR \ Drop \ noADC", + "split_mode_pwm_dt_irdrop": + r"Split \ Mode \ Advanced \ IR \ Drop", + "bit_wise_dt_irdrop": + r"Bit \ Wise \ Mode \ Advanced \ IR \ Drop", } # Move the model and tensors to cuda if it is available. 
@@ -211,30 +224,61 @@ def network_comparison( model_dict = {} - # conventional time-dependent ir drop - rpu_config_conventional_dt_irdrop = rpu_config_modifications(TorchInferenceRPUConfigIRDropT()) - model_conventional_dt_irdrop = AnalogLinear( - IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_conventional_dt_irdrop + # default conventional time-dependent ir drop + rpu_config_default_conventional_dt_irdrop = rpu_config_modifications( + TorchInferenceRPUConfigIRDropT() + ) + model_default_conventional_dt_irdrop = AnalogLinear( + IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_default_conventional_dt_irdrop + ) + model_dict.update({"default_conventional_model_dt_irdrop": + model_default_conventional_dt_irdrop}) + + # conventional time-dependent ir drop with input-dependent PCM read noise (flag testing) + rpu_config_conventional_dt_irdrop_PCMnoise = rpu_config_modifications( + TorchInferenceRPUConfigIRDropT() + ) + rpu_config_conventional_dt_irdrop_PCMnoise.forward.apply_xdep_pcm_read_noise = True + rpu_config_conventional_dt_irdrop_PCMnoise.forward.xdep_pcm_read_noise_scale = 1.0 + model_conventional_dt_irdrop_PCMnoise = AnalogLinear( + IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_conventional_dt_irdrop_PCMnoise + ) + model_dict.update({"conventional_model_dt_irdrop_PCMnoise": + model_conventional_dt_irdrop_PCMnoise}) + + # conventional time-dependent ir drop without ADC quantization (flag testing) + rpu_config_conventional_dt_irdrop_noADC = rpu_config_modifications( + TorchInferenceRPUConfigIRDropT() + ) + rpu_config_conventional_dt_irdrop_noADC.forward.adc_quantization = False + model_conventional_dt_irdrop_noADC = AnalogLinear( + IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_conventional_dt_irdrop_noADC ) - model_dict.update({"conventional_model_dt_irdrop": model_conventional_dt_irdrop}) + model_dict.update({"conventional_model_dt_irdrop_noADC": + model_conventional_dt_irdrop_noADC}) # split mode pwm 
time-dependent ir drop - rpu_config_split_mode_dt_irdrop = rpu_config_modifications(TorchInferenceRPUConfigIRDropT()) + rpu_config_split_mode_dt_irdrop = rpu_config_modifications( + TorchInferenceRPUConfigIRDropT() + ) rpu_config_split_mode_dt_irdrop.forward.mv_type = AnalogMVType.SPLIT_MODE rpu_config_split_mode_dt_irdrop.forward.ir_drop_bit_shift = 3 - rpu_config_split_mode_dt_irdrop.forward.split_mode_pwm = AnalogMVType.SPLIT_MODE model_split_mode_dt_irdrop = AnalogLinear( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_split_mode_dt_irdrop ) - model_dict.update({"split_mode_pwm_dt_irdrop": model_split_mode_dt_irdrop}) + model_dict.update({"split_mode_pwm_dt_irdrop": + model_split_mode_dt_irdrop}) # bit wise time-dependent ir drop - rpu_config_bitwise_dt_irdrop = rpu_config_modifications(TorchInferenceRPUConfigIRDropT()) + rpu_config_bitwise_dt_irdrop = rpu_config_modifications( + TorchInferenceRPUConfigIRDropT() + ) rpu_config_bitwise_dt_irdrop.forward.mv_type = AnalogMVType.BIT_WISE model_bitwise_dt_irdrop = AnalogLinear( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_bitwise_dt_irdrop ) - model_dict.update({"bit_wise_dt_irdrop": model_bitwise_dt_irdrop}) + model_dict.update({"bit_wise_dt_irdrop": + model_bitwise_dt_irdrop}) # default model rpu_config_default = rpu_config_modifications(InferenceRPUConfig()) diff --git a/src/aihwkit/simulator/parameters/io.py b/src/aihwkit/simulator/parameters/io.py index 846e457c..95ae810b 100644 --- a/src/aihwkit/simulator/parameters/io.py +++ b/src/aihwkit/simulator/parameters/io.py @@ -2,7 +2,13 @@ # (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. # -# Licensed under the MIT license. See LICENSE file in the project root for details. +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. # pylint: disable=too-many-instance-attributes @@ -32,11 +38,7 @@ class IOParameters(_PrintableMixin): """Short-cut to compute a perfect forward pass. If ``True``, it assumes an ideal forward pass (e.g. no bound, ADC etc...). - Will disregard all other IOParameters settings in this case. - - Note that other noise sources set by - :class:`aihwkit.simulator.parameters.inference.WeightModifierParameter` - and :class:`aihwkit.inference.noise` will still be applied. + Will disregard all other settings in this case. """ mv_type: AnalogMVType = AnalogMVType.ONE_PASS @@ -176,7 +178,7 @@ class IOParameters(_PrintableMixin): analog output (before the ADC), i.e. :math:`\frac{y_i}{1 + n_i*|y_i|}` where :math:`n_i` is drawn at the instantiation time by:: - out_nonlinearity / out_bound * (1 + out_nonlinearity_std * rand) + out_nonlinearity / out_bound * (1 + out_nonlinearity_std * rand) """ out_nonlinearity_std: float = 0.0 @@ -331,3 +333,35 @@ class IOParametersIRDropT(IOParameters): PWM/DAC operation increases throughput / energy efficiency of MVM tile hardware while potentially sacrificing some analog MVM accuracy.""" + + apply_xdep_pcm_read_noise: bool = False + """Sets whether to apply activation (x)-dependent PCM read + noise. This model is implemented within analog_mvm_irdrop_t due + to the inputs needing to be 'prepared' (converted to a ns scale), + which is only accomplished deeper into the mvm implementation (as + opposed to the noise_model parameters within PCMLikeNoiseModel()). + """ + + xdep_pcm_read_noise_scale: float = 1.0 + """Scale for the activation (x)-dependent PCM read noise model. 
+ """ + + adc_quantization: bool = True + """Sets whether the ADC Quantization feature is applied; this + implements the 'floor' operation consistent with the behavior a + Current-Controlled Oscillator (CCO) (during current integration) + in the case where a capacitor is not fully charged to achieve a + pulse/oscillation. This feature is particularly important in + capturing the true behavior of SPLIT PWM mode. This feature is + implemented within analog_mvm_irdrop_t taken that it requires + charge units, rather than unit-less quantities. + """ + adc_frequency: float = 6.24 + """Sets the operating frequency of the ADC used for quantization; + quantity provided is in GHz. + """ + + ir_drop_integration_sum: bool = False + """Sets current integration to use summation rather than the default + trapezoidal integation method. + """ \ No newline at end of file diff --git a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py index 3bb8bba6..3a786fa3 100644 --- a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py +++ b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py @@ -2,7 +2,13 @@ # (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. # -# Licensed under the MIT license. See LICENSE file in the project root for details. +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. 
# pylint: disable=too-many-locals, too-many-arguments @@ -14,8 +20,8 @@ from torch import ( Tensor, empty, - zeros, sum as torch_sum, + trapz as torch_trapz, flip, abs as torch_abs, sign, @@ -24,7 +30,12 @@ allclose, linspace, Size, + mean as torch_mean, + randn_like, + trunc, + log10 as torch_log10, ) +import torch from torch.autograd import no_grad from torch.nn.functional import pad @@ -78,7 +89,7 @@ def _interleave_cols_nd(cls, mvm1: Tensor, mvm2: Tensor) -> Tensor: IR drop in both directions (to north and south ADCs) in a symmetric tile design. """ - shape = Size(mvm1.shape[0:-1]) + Size([mvm1.shape[-1] + mvm2.shape[-1]]) + shape = Size(mvm1.shape[0:-1]) + Size([2 * mvm1.shape[-1]]) mvm = empty(*shape).to(mvm1.device) mvm[..., 0::2] = mvm1 mvm[..., 1::2] = mvm2 @@ -194,7 +205,7 @@ def _prepare_inputs( prepared_input = [prepared_input_lsb, prepared_input_msb] elif io_pars.mv_type == AnalogMVType.BIT_WISE: - int_input = prepared_input / (2.0 * res) + int_input = prepared_input / (2. * res) n_bits = int(log2(1.0 / res + 2)) - 1 prepared_input = [] for _ in range(n_bits): @@ -212,6 +223,75 @@ def _prepare_inputs( raise ConfigError(f"Unknown AnalogMVType {io_pars.mv_type}") return prepared_input + @classmethod + def _apply_xdep_pcm_read_noise( + cls, + weight: Tensor, + input_: Tensor, + io_pars: IOParametersIRDropT, + g_converter: SinglePairConductanceConverter, + res: float + ) -> Tensor: + """ + Applies input-dependent PCM read noise to weights + """ + + read_noise_scale = io_pars.xdep_pcm_read_noise_scale + + ## Default model analytical-fit coefficients + sigma_noise_slope_coefficients = [0.00021926, -0.00187352, + 0.00655714, -0.0146159, 0.02578764] + + sigma_noise_offset_coefficients = [1.11906167e-09, -9.08576764e-09, + 3.07015063e-08, -6.77079241e-08, 1.17763144e-07] + + # First, convert activations/inputs into integer ns units + t_integration = torch_mean(torch_abs(input_) / (2*res)) + + if t_integration > 0.: + x = torch_log10(t_integration) + else: + 
torch.tensor(0., dtype=t_integration.dtype,device=t_integration.device) + + # Convert weights into conductance units + conductances_lst, params = g_converter.convert_to_conductances(weight) + conductances = torch_abs(conductances_lst[0] - conductances_lst[1]) #[uS] + + sig_noise_slope = ( sigma_noise_slope_coefficients[0]*(x**4) + + sigma_noise_slope_coefficients[1]*(x**3) + + sigma_noise_slope_coefficients[2]*(x**2) + + sigma_noise_slope_coefficients[3]*(x**1) + + sigma_noise_slope_coefficients[4] + ) + sig_noise_offset = ( sigma_noise_offset_coefficients[0]*(x**4) + + sigma_noise_offset_coefficients[1]*(x**3) + + sigma_noise_offset_coefficients[2]*(x**2) + + sigma_noise_offset_coefficients[3]*(x**1) + + sigma_noise_offset_coefficients[4] + ) * 1e6 #[uS] + sig_noise = (sig_noise_slope * conductances) + sig_noise_offset + g_final = conductances + read_noise_scale * sig_noise * randn_like(weight) + + # Turn conductances back into unitless weights with original sign preserved + weight = g_final * (sign(weight)) / params['scale_ratio'] + + return weight + + @classmethod + def _apply_adc_quantization( + cls, + mvm: Tensor, + io_pars: IOParametersIRDropT, + ) -> Tensor: + """ + Applies ADC Quantization + """ + + adc_ticks_even = trunc((mvm / 100) * io_pars.adc_frequency) + mvm = adc_ticks_even * 100 / io_pars.adc_frequency + + return mvm + @classmethod @no_grad() def _thev_equiv( @@ -285,10 +365,10 @@ def _thev_equiv( vth_3d = output[0, :, :, :] rth_3d = output[1, :, :, :] - vth_nd = vth_3d.view(*out_sh) # reshape back to appropriate dimensions + vth_nd = vth_3d.view(*out_sh) # reshape back to appropriate dimensions rth_nd = rth_3d.view(*out_sh) - return vth_3d, rth_3d + return vth_nd, rth_nd gp_4d = gp_2d[None, :, :, None] gm_4d = gm_2d[None, :, :, None] @@ -322,8 +402,8 @@ def sum_segs(g_values: Tensor) -> Tensor: g_4d = sum_segs(g_4d) # regularized to avoid device-by-zero - gth_4d_segs += g_4d + 1e-12 # atomic Thev equiv conductance [uS] - vth_4d_segs += 0.4 * g_4d # 
atomic Thevenin equivalent resistance [MOhm] + gth_4d_segs += g_4d + 1e-12 # atomic Thev equiv conductance [uS] + vth_4d_segs += 0.4 * g_4d # atomic Thevenin equivalent resistance [MOhm] vth_4d_segs /= gth_4d_segs g_4d = None @@ -339,7 +419,7 @@ def sum_segs(g_values: Tensor) -> Tensor: vth_3d = (vth_3d / r_1 + vth_4d_segs[:, seg, :, :] / r_2) * rth_3d rth_3d += 0.5 * rw_segs - vth_nd = vth_3d.view(*out_sh) # reshape back to appropriate dimensions + vth_nd = vth_3d.view(*out_sh) # reshape back to appropriate dimensions rth_nd = rth_3d.view(*out_sh) return vth_nd, rth_nd # rth_nd in MOhm @@ -379,6 +459,9 @@ def _matmul_irdrop( res = cls._get_res(io_pars.inp_res) bit_res = io_pars.inp_bound / res + if io_pars.apply_xdep_pcm_read_noise: + weight = cls._apply_xdep_pcm_read_noise(weight, input_, io_pars, g_converter, res) + if ir_drop == 0.0: return super(AnalogMVMIRDropT, cls)._matmul(weight, input_, trans) @@ -398,44 +481,56 @@ def _matmul_irdrop( ) input_, new_weight = cls._pad_symmetric(input_, new_weight, phys_input_size=phys_input_size) - if g_converter is None: - g_converter = SinglePairConductanceConverter() # type: ignore + g_lst, params = g_converter.convert_to_conductances(new_weight) - mvm = zeros((input_.shape[0], g_lst[0].shape[1])).to(input_.device) - # (f1, (gp1, gm1)), (f2, (gp2, gm2), ... 
, (fn, (gpn, gmn)) # low to highest significance - for f_factor, g_lst_pair in zip(params.get("f_lst", [1.0]), zip(g_lst[::2], g_lst[1::2])): - vth_nd, rth_nd = cls._thev_equiv( - input_, - [g[:, 0::2] for g in g_lst_pair], # even cols - time_steps=time_steps, - t_max=t_max, - segments=io_pars.ir_drop_segments, - r_s=ir_drop_rs, - phys_input_size=phys_input_size, - ) - i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA - mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) # * x batch_size x n_cols/2 + vth_nd, rth_nd = cls._thev_equiv( + input_, + [g[:, 0::2] for g in g_lst], # even cols + time_steps=time_steps, + t_max=t_max, + segments=io_pars.ir_drop_segments, + r_s=ir_drop_rs, + phys_input_size=phys_input_size, + ) + i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA - vth_nd, rth_nd = cls._thev_equiv( - flip(input_, (-1,)), # flip input - [flip(g[:, 1::2], (0,)) for g in g_lst_pair], # odd cols - time_steps=time_steps, - t_max=t_max, - segments=io_pars.ir_drop_segments, - r_s=ir_drop_rs, - phys_input_size=phys_input_size, - ) - i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA - mvm_odd_col_up_adc = torch_sum(i_out_nd, dim=-1) # * x batch_size x n_cols/2 + if io_pars.ir_drop_integration_sum: + mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) + else: + # Insert a '0th' time step trapz integration + i_out_nd = torch.cat( (i_out_nd[...,0].unsqueeze(-1), i_out_nd), -1) + mvm_even_col_down_adc = torch_trapz(i_out_nd, dim=-1) + + if io_pars.adc_quantization: + mvm_even_col_down_adc = cls._apply_adc_quantization(mvm_even_col_down_adc, io_pars) + + vth_nd, rth_nd = cls._thev_equiv( + flip(input_, (-1,)), # flip input + [flip(g[:, 1::2], (0,)) for g in g_lst], # odd cols + time_steps=time_steps, + t_max=t_max, + segments=io_pars.ir_drop_segments, + r_s=ir_drop_rs, + phys_input_size=phys_input_size, + ) + i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA - mvm += f_factor * cls._interleave_cols_nd( - mvm_even_col_down_adc, 
mvm_odd_col_up_adc - ) # symmetric ADCs + if io_pars.ir_drop_integration_sum: + mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) + else: + # Insert a '0th' time step trapz integration + i_out_nd = torch.cat( (i_out_nd[...,0].unsqueeze(-1), i_out_nd), -1) + mvm_odd_col_up_adc = torch_trapz(i_out_nd, dim=-1) - mvm /= params["scale_ratio"] # conductance normalization - mvm /= 0.2 # hardware normalization - mvm /= bit_res / 2.0 # normalize + if io_pars.adc_quantization: + mvm_odd_col_up_adc = cls._apply_adc_quantization(mvm_odd_col_up_adc, io_pars) + + mvm = cls._interleave_cols_nd(mvm_even_col_down_adc, mvm_odd_col_up_adc) # symmetric ADCs + + mvm /= params['scale_ratio'] # conductance normalization + mvm /= 0.2 # hardware normalization + mvm /= bit_res / 2.0 # normalize return mvm @classmethod @@ -476,6 +571,7 @@ def _compute_analog_mv( # type: ignore ConfigError: If unknown AnalogMVType """ + ir_drop = io_pars.ir_drop if is_test else 0.0 prepared_input = cls._prepare_inputs( @@ -525,7 +621,6 @@ def _compute_analog_mv( # type: ignore bound_test_passed_lsb, finalized_outputs_lsb = cls._finalize_output( out_values=out_values_lsb, io_pars=io_pars, **fwd_pars ) - time_steps = 2 ** int( log2(bit_res) - io_pars.split_mode_bit_shift - 1 ) # minus 1 for sign bit @@ -545,11 +640,10 @@ def _compute_analog_mv( # type: ignore bound_test_passed_msb, finalized_outputs_msb = cls._finalize_output( out_values=out_values_msb, io_pars=io_pars, **fwd_pars ) - finalized_outputs = ( finalized_outputs_lsb + (2**io_pars.split_mode_bit_shift) * finalized_outputs_msb ) - bound_test_passed = bound_test_passed_lsb * bound_test_passed_msb + bound_test_passed = ( bound_test_passed_lsb * bound_test_passed_msb ) elif io_pars.mv_type == AnalogMVType.BIT_WISE: finalized_outputs, bound_test_passed = 0.0, True From 864b09097a6bd74e2df2ba4e88efb524ea4e9e7b Mon Sep 17 00:00:00 2001 From: "Jose Luquin jluquin@ibm.com" Date: Fri, 5 Dec 2025 13:21:21 -0500 Subject: [PATCH 2/4] 
Modifications/Cleanup/Improvements to IRDrop example and implementation; signoff included Signed-off-by: Jose Luquin jluquin@ibm.com Signed-off-by: Jose Luquin jluquin@ibm.com --- examples/28_advanced_irdrop.py | 13 +++--- src/aihwkit/simulator/parameters/io.py | 2 +- .../simulator/tiles/analog_mvm_irdrop_t.py | 44 +++++++++---------- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/examples/28_advanced_irdrop.py b/examples/28_advanced_irdrop.py index c3e5e625..7c417079 100644 --- a/examples/28_advanced_irdrop.py +++ b/examples/28_advanced_irdrop.py @@ -232,19 +232,18 @@ def network_comparison( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_default_conventional_dt_irdrop ) model_dict.update({"default_conventional_model_dt_irdrop": - model_default_conventional_dt_irdrop}) + model_default_conventional_dt_irdrop}) # conventional time-dependent ir drop with input-dependent PCM read noise (flag testing) rpu_config_conventional_dt_irdrop_PCMnoise = rpu_config_modifications( - TorchInferenceRPUConfigIRDropT() - ) + TorchInferenceRPUConfigIRDropT()) rpu_config_conventional_dt_irdrop_PCMnoise.forward.apply_xdep_pcm_read_noise = True rpu_config_conventional_dt_irdrop_PCMnoise.forward.xdep_pcm_read_noise_scale = 1.0 model_conventional_dt_irdrop_PCMnoise = AnalogLinear( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_conventional_dt_irdrop_PCMnoise ) model_dict.update({"conventional_model_dt_irdrop_PCMnoise": - model_conventional_dt_irdrop_PCMnoise}) + model_conventional_dt_irdrop_PCMnoise}) # conventional time-dependent ir drop without ADC quantization (flag testing) rpu_config_conventional_dt_irdrop_noADC = rpu_config_modifications( @@ -255,7 +254,7 @@ def network_comparison( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_conventional_dt_irdrop_noADC ) model_dict.update({"conventional_model_dt_irdrop_noADC": - model_conventional_dt_irdrop_noADC}) + model_conventional_dt_irdrop_noADC}) # split mode pwm time-dependent ir 
drop rpu_config_split_mode_dt_irdrop = rpu_config_modifications( @@ -267,7 +266,7 @@ def network_comparison( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_split_mode_dt_irdrop ) model_dict.update({"split_mode_pwm_dt_irdrop": - model_split_mode_dt_irdrop}) + model_split_mode_dt_irdrop}) # bit wise time-dependent ir drop rpu_config_bitwise_dt_irdrop = rpu_config_modifications( @@ -278,7 +277,7 @@ def network_comparison( IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config_bitwise_dt_irdrop ) model_dict.update({"bit_wise_dt_irdrop": - model_bitwise_dt_irdrop}) + model_bitwise_dt_irdrop}) # default model rpu_config_default = rpu_config_modifications(InferenceRPUConfig()) diff --git a/src/aihwkit/simulator/parameters/io.py b/src/aihwkit/simulator/parameters/io.py index 95ae810b..490fde17 100644 --- a/src/aihwkit/simulator/parameters/io.py +++ b/src/aihwkit/simulator/parameters/io.py @@ -364,4 +364,4 @@ class IOParametersIRDropT(IOParameters): ir_drop_integration_sum: bool = False """Sets current integration to use summation rather than the default trapezoidal integation method. 
- """ \ No newline at end of file + """ diff --git a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py index 3a786fa3..66c66e9c 100644 --- a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py +++ b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py @@ -238,37 +238,37 @@ def _apply_xdep_pcm_read_noise( read_noise_scale = io_pars.xdep_pcm_read_noise_scale - ## Default model analytical-fit coefficients + # Default model analytical-fit coefficients sigma_noise_slope_coefficients = [0.00021926, -0.00187352, - 0.00655714, -0.0146159, 0.02578764] + 0.00655714, -0.0146159, 0.02578764] sigma_noise_offset_coefficients = [1.11906167e-09, -9.08576764e-09, - 3.07015063e-08, -6.77079241e-08, 1.17763144e-07] + 3.07015063e-08, -6.77079241e-08, 1.17763144e-07] # First, convert activations/inputs into integer ns units - t_integration = torch_mean(torch_abs(input_) / (2*res)) + t_integration = torch_mean(torch_abs(input_) / (2 * res)) if t_integration > 0.: x = torch_log10(t_integration) else: - torch.tensor(0., dtype=t_integration.dtype,device=t_integration.device) + x = torch.tensor(0., dtype=t_integration.dtype, device=t_integration.device) # Convert weights into conductance units conductances_lst, params = g_converter.convert_to_conductances(weight) - conductances = torch_abs(conductances_lst[0] - conductances_lst[1]) #[uS] - - sig_noise_slope = ( sigma_noise_slope_coefficients[0]*(x**4) + - sigma_noise_slope_coefficients[1]*(x**3) + - sigma_noise_slope_coefficients[2]*(x**2) + - sigma_noise_slope_coefficients[3]*(x**1) + - sigma_noise_slope_coefficients[4] - ) - sig_noise_offset = ( sigma_noise_offset_coefficients[0]*(x**4) + - sigma_noise_offset_coefficients[1]*(x**3) + - sigma_noise_offset_coefficients[2]*(x**2) + - sigma_noise_offset_coefficients[3]*(x**1) + - sigma_noise_offset_coefficients[4] - ) * 1e6 #[uS] + conductances = torch_abs(conductances_lst[0] - conductances_lst[1]) # [uS] + + sig_noise_slope =
(sigma_noise_slope_coefficients[0] * (x ** 4) + + sigma_noise_slope_coefficients[1] * (x**3) + + sigma_noise_slope_coefficients[2] * (x**2) + + sigma_noise_slope_coefficients[3] * (x**1) + + sigma_noise_slope_coefficients[4] + ) + sig_noise_offset = (sigma_noise_offset_coefficients[0] * (x**4) + + sigma_noise_offset_coefficients[1] * (x**3) + + sigma_noise_offset_coefficients[2] * (x**2) + + sigma_noise_offset_coefficients[3] * (x**1) + + sigma_noise_offset_coefficients[4] + ) * 1e6 # [uS] sig_noise = (sig_noise_slope * conductances) + sig_noise_offset g_final = conductances + read_noise_scale * sig_noise * randn_like(weight) @@ -499,7 +499,7 @@ def _matmul_irdrop( mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) else: # Insert a '0th' time step trapz integration - i_out_nd = torch.cat( (i_out_nd[...,0].unsqueeze(-1), i_out_nd), -1) + i_out_nd = torch.cat((i_out_nd[..., 0].unsqueeze(-1), i_out_nd), -1) mvm_even_col_down_adc = torch_trapz(i_out_nd, dim=-1) if io_pars.adc_quantization: @@ -520,7 +520,7 @@ def _matmul_irdrop( mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) else: # Insert a '0th' time step trapz integration - i_out_nd = torch.cat( (i_out_nd[...,0].unsqueeze(-1), i_out_nd), -1) + i_out_nd = torch.cat((i_out_nd[..., 0].unsqueeze(-1), i_out_nd), -1) mvm_odd_col_up_adc = torch_trapz(i_out_nd, dim=-1) if io_pars.adc_quantization: @@ -643,7 +643,7 @@ def _compute_analog_mv( # type: ignore finalized_outputs = ( finalized_outputs_lsb + (2**io_pars.split_mode_bit_shift) * finalized_outputs_msb ) - bound_test_passed = ( bound_test_passed_lsb * bound_test_passed_msb ) + bound_test_passed = (bound_test_passed_lsb * bound_test_passed_msb) elif io_pars.mv_type == AnalogMVType.BIT_WISE: finalized_outputs, bound_test_passed = 0.0, True From 5e766b36a5454d974e8c2d2d6cb2031f1acdaaf9 Mon Sep 17 00:00:00 2001 From: "Jose Luquin jluquin@ibm.com" Date: Fri, 5 Dec 2025 18:15:35 -0500 Subject: [PATCH 3/4] PR modifications and corrections Signed-off-by: Jose Luquin 
jluquin@ibm.com --- .../simulator/tiles/analog_mvm_irdrop_t.py | 91 ++++++++++--------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py index 66c66e9c..ca8c841a 100644 --- a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py +++ b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py @@ -34,6 +34,7 @@ randn_like, trunc, log10 as torch_log10, + zeros ) import torch from torch.autograd import no_grad @@ -287,8 +288,8 @@ def _apply_adc_quantization( Applies ADC Quantization """ - adc_ticks_even = trunc((mvm / 100) * io_pars.adc_frequency) - mvm = adc_ticks_even * 100 / io_pars.adc_frequency + adc_ticks = trunc((mvm / 100) * io_pars.adc_frequency) + mvm = adc_ticks * 100 / io_pars.adc_frequency return mvm @@ -484,49 +485,55 @@ def _matmul_irdrop( g_lst, params = g_converter.convert_to_conductances(new_weight) - vth_nd, rth_nd = cls._thev_equiv( - input_, - [g[:, 0::2] for g in g_lst], # even cols - time_steps=time_steps, - t_max=t_max, - segments=io_pars.ir_drop_segments, - r_s=ir_drop_rs, - phys_input_size=phys_input_size, - ) - i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA - - if io_pars.ir_drop_integration_sum: - mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) - else: - # Insert a '0th' time step trapz integration - i_out_nd = torch.cat((i_out_nd[..., 0].unsqueeze(-1), i_out_nd), -1) - mvm_even_col_down_adc = torch_trapz(i_out_nd, dim=-1) - - if io_pars.adc_quantization: - mvm_even_col_down_adc = cls._apply_adc_quantization(mvm_even_col_down_adc, io_pars) - - vth_nd, rth_nd = cls._thev_equiv( - flip(input_, (-1,)), # flip input - [flip(g[:, 1::2], (0,)) for g in g_lst], # odd cols - time_steps=time_steps, - t_max=t_max, - segments=io_pars.ir_drop_segments, - r_s=ir_drop_rs, - phys_input_size=phys_input_size, - ) - i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA + mvm = zeros((input_.shape[0], 
g_lst[0].shape[1])).to(input_.device) + # (f1, (gp1, gm1)), (f2, (gp2, gm2), ... , (fn, (gpn, gmn)) # low to highest significance + for f_factor, g_lst_pair in zip(params.get("f_lst", [1.0]), zip(g_lst[::2], g_lst[1::2])): + vth_nd, rth_nd = cls._thev_equiv( + input_, + [g[:, 0::2] for g in g_lst_pair], # even cols + time_steps=time_steps, + t_max=t_max, + segments=io_pars.ir_drop_segments, + r_s=ir_drop_rs, + phys_input_size=phys_input_size, + ) + i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA + + if io_pars.ir_drop_integration_sum: + mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) + else: + # Insert a '0th' time step trapz integration + i_out_nd = torch.cat((i_out_nd[..., 0].unsqueeze(-1), i_out_nd), -1) + mvm_even_col_down_adc = torch_trapz(i_out_nd, dim=-1) + + if io_pars.adc_quantization: + print('applying 1') + mvm_even_col_down_adc = cls._apply_adc_quantization(mvm_even_col_down_adc, io_pars) + + vth_nd, rth_nd = cls._thev_equiv( + flip(input_, (-1,)), # flip input + [flip(g[:, 1::2], (0,)) for g in g_lst_pair], # odd cols + time_steps=time_steps, + t_max=t_max, + segments=io_pars.ir_drop_segments, + r_s=ir_drop_rs, + phys_input_size=phys_input_size, + ) + i_out_nd = (vth_nd - io_pars.ir_drop_v_read) / rth_nd # uA - if io_pars.ir_drop_integration_sum: - mvm_even_col_down_adc = torch_sum(i_out_nd, dim=-1) - else: - # Insert a '0th' time step trapz integration - i_out_nd = torch.cat((i_out_nd[..., 0].unsqueeze(-1), i_out_nd), -1) - mvm_odd_col_up_adc = torch_trapz(i_out_nd, dim=-1) + if io_pars.ir_drop_integration_sum: + mvm_odd_col_up_adc = torch_sum(i_out_nd, dim=-1) + else: + # Insert a '0th' time step trapz integration + i_out_nd = torch.cat((i_out_nd[..., 0].unsqueeze(-1), i_out_nd), -1) + mvm_odd_col_up_adc = torch_trapz(i_out_nd, dim=-1) - if io_pars.adc_quantization: - mvm_odd_col_up_adc = cls._apply_adc_quantization(mvm_odd_col_up_adc, io_pars) + if io_pars.adc_quantization: + mvm_odd_col_up_adc = 
cls._apply_adc_quantization(mvm_odd_col_up_adc, io_pars) - mvm = cls._interleave_cols_nd(mvm_even_col_down_adc, mvm_odd_col_up_adc) # symmetric ADCs + mvm += f_factor * cls._interleave_cols_nd( + mvm_even_col_down_adc, mvm_odd_col_up_adc + ) # symmetric ADCs mvm /= params['scale_ratio'] # conductance normalization mvm /= 0.2 # hardware normalization From 5b59105afe3ab3136ac6cb64aa4c32f3f957988c Mon Sep 17 00:00:00 2001 From: "Jose Luquin jluquin@ibm.com" Date: Mon, 22 Dec 2025 17:58:55 -0500 Subject: [PATCH 4/4] PR modifications/update Signed-off-by: Jose Luquin jluquin@ibm.com --- examples/28_advanced_irdrop.py | 10 ++-------- src/aihwkit/simulator/parameters/io.py | 10 ++-------- src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py | 11 ++--------- 3 files changed, 6 insertions(+), 25 deletions(-) diff --git a/examples/28_advanced_irdrop.py b/examples/28_advanced_irdrop.py index 7c417079..5d1d53e0 100644 --- a/examples/28_advanced_irdrop.py +++ b/examples/28_advanced_irdrop.py @@ -1,14 +1,8 @@ # -*- coding: utf-8 -*- -# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. +# (C) Copyright 2020, 2021, 2022, 2023, 2024, 2025 IBM. All Rights Reserved. # -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. +# Licensed under the MIT license. See LICENSE file in the project root for details. 
"""aihwkit example 28: advanced (time-dependent) IR drop effects diff --git a/src/aihwkit/simulator/parameters/io.py b/src/aihwkit/simulator/parameters/io.py index 490fde17..71e63f08 100644 --- a/src/aihwkit/simulator/parameters/io.py +++ b/src/aihwkit/simulator/parameters/io.py @@ -1,14 +1,8 @@ # -*- coding: utf-8 -*- -# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. +# (C) Copyright 2020, 2021, 2022, 2023, 2024, 2025 IBM. All Rights Reserved. # -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. +# Licensed under the MIT license. See LICENSE file in the project root for details. # pylint: disable=too-many-instance-attributes diff --git a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py index ca8c841a..7e91eb8d 100644 --- a/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py +++ b/src/aihwkit/simulator/tiles/analog_mvm_irdrop_t.py @@ -1,14 +1,8 @@ # -*- coding: utf-8 -*- -# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. +# (C) Copyright 2020, 2021, 2022, 2023, 2024, 2025 IBM. All Rights Reserved. # -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. +# Licensed under the MIT license. See LICENSE file in the project root for details. 
# pylint: disable=too-many-locals, too-many-arguments @@ -507,7 +501,6 @@ def _matmul_irdrop( mvm_even_col_down_adc = torch_trapz(i_out_nd, dim=-1) if io_pars.adc_quantization: - print('applying 1') mvm_even_col_down_adc = cls._apply_adc_quantization(mvm_even_col_down_adc, io_pars) vth_nd, rth_nd = cls._thev_equiv(