
Conversation

jubueche (Collaborator) commented Nov 6, 2025

Related issues

#627

Description

The tile array wasn't used when instantiating an analog conv layer with an RPU config whose max. mapping size was smaller than the layer size.
Fix: pass the 2D layer shape when instantiating the analog tile.
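
For context, here is a minimal sketch of the mapping arithmetic involved (plain PyTorch; the tile counts are illustrative and this is not aihwkit's actual instantiation code): a Conv2d weight is treated as a 2D matrix of shape [out_channels, in_channels * kernel_height * kernel_width], and the max mapping sizes determine how many tiles the logical array needs.

import math
import torch.nn as nn

# Illustration only: how the conv layer from the test below maps to a 2D shape
# and why max mapping sizes of 10 require a tile array.
conv = nn.Conv2d(1, 32, 3, 1)

out_size = conv.out_channels                                             # 32
in_size = conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1]   # 1 * 3 * 3 = 9

max_input_size, max_output_size = 10, 10  # as in the test below

n_tiles_in = math.ceil(in_size / max_input_size)     # 1
n_tiles_out = math.ceil(out_size / max_output_size)  # 4

print(f"2D layer shape: ({out_size}, {in_size})")
print(f"logical tile array: {n_tiles_out} x {n_tiles_in} tiles")
# If only the 4D kernel shape (and not this 2D shape) reaches tile
# instantiation, the layer looks small enough for a single tile and no
# tile array is created, which is the bug described above.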

Details

Small test to check correctness:

import torch
import torch.nn as nn
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.nn.conversion import convert_to_analog, convert_to_analog_mapped


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)

    def forward(self, x):
        x = self.conv1(x)
        return x


def test_mapping_correctness():
    """Test that convert_to_analog and convert_to_analog_mapped produce the same results."""
    model = Net()
    rpu_config = InferenceRPUConfig()
    rpu_config.forward.is_perfect = True

    rpu_config.mapping.max_input_size = 10
    rpu_config.mapping.max_output_size = 10
    
    # Convert using both methods
    model_1 = convert_to_analog(model, rpu_config)
    model_2 = convert_to_analog_mapped(model, rpu_config)
    
    # Create test input
    test_input = torch.randn(1, 1, 28, 28)
    
    # Set both models to eval mode
    model_1.eval()
    model_2.eval()
    
    # Get outputs from both models
    with torch.no_grad():
        output_1 = model_1(test_input)
        output_2 = model_2(test_input)
    
    # Check if outputs are close (allowing for small numerical differences)
    are_close = torch.allclose(output_1, output_2, rtol=1e-5, atol=1e-6)
    
    print(f"Outputs are close: {are_close}")
    print(f"Max difference: {torch.max(torch.abs(output_1 - output_2)).item()}")
    
    if not are_close:
        print("WARNING: Outputs differ between conversion methods!")
    else:
        print("SUCCESS: Both conversion methods produce identical results!")
    
    return are_close


if __name__ == "__main__":
    model = Net()
    rpu_config = InferenceRPUConfig()
    rpu_config.mapping.max_input_size = 10
    rpu_config.mapping.max_output_size = 10
    
    model_1 = convert_to_analog(model, rpu_config)
    print("convert_to_analog - conv1")
    for tile in model_1.conv1.analog_tiles():
        print("\t", tile)

    model_2 = convert_to_analog_mapped(model, rpu_config)
    print("convert_to_analog_mapped - conv1")
    for tile in model_2.conv1.analog_tiles():
        print("\t", tile)
    
    # Run the correctness test
    print("\n" + "="*50)
    print("Testing mapping correctness...")
    print("="*50)
    test_mapping_correctness()

Signed-off-by: Julian Buechel <jub@zurich.ibm.com>
jubueche (Collaborator, Author) commented Nov 6, 2025

@PabloCarmona locally running make mypy does not result in errors. Can you have a look at what is going on?

Signed-off-by: Julian Buechel <jub@zurich.ibm.com>
jubueche (Collaborator, Author) commented Nov 6, 2025

> @PabloCarmona locally running make mypy does not result in errors. Can you have a look at what is going on?

Fixed it by changing the Python version for mypy in the setup config. Surprised this did not trigger errors before; maybe a new matplotlib version is now being used in the tests.
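
For reference, a mypy Python-version pin in setup.cfg looks roughly like the snippet below; the exact version value is a placeholder, since the thread does not state which one was set:

[mypy]
# placeholder value -- the actual version used in the fix is not shown here
python_version = 3.10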

anu-pub (Collaborator) commented Nov 14, 2025

@maljoras-sony would you have time to quickly check if the fix we wrote looks correct? This is one of the last remaining bugs we have in the toolkit.

maljoras (Collaborator) left a comment

I don't see where in the code changes the issue is fixed, and the change in get default tile class looks wrong to me... But after a more careful check, it is actually correct: the returned class depends on the sizes.
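
For readers following along, a hypothetical sketch of the logic under discussion (SingleTile, TileArray, and select_tile_class are illustrative names, not aihwkit's actual classes or functions): the returned class depends on whether the 2D layer shape exceeds the configured max mapping sizes.

class SingleTile:
    """Illustrative stand-in for a plain analog tile module."""

class TileArray:
    """Illustrative stand-in for a logical array of tiles."""

def select_tile_class(out_size, in_size, max_output_size, max_input_size):
    # Return the array class only when the 2D layer shape exceeds one tile.
    if out_size > max_output_size or in_size > max_input_size:
        return TileArray
    return SingleTile

# With the conv layer from the description (2D shape 32 x 9) and max sizes
# of 10, the array class is selected only if the 2D shape is passed in:
assert select_tile_class(32, 9, 10, 10) is TileArray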

maljoras previously approved these changes Nov 14, 2025

maljoras (Collaborator) left a comment

Seems actually correct, as the class is different based on the sizes.

anu-pub (Collaborator) commented Nov 14, 2025

@jubueche If the fix is correct, we should deprecate convert_to_analog_mapped so that only convert_to_analog is used.

jubueche (Collaborator, Author) commented

@anu-pub I think the proper way would be:

  1. Fix this bug in the current version/release, but leave the rest as-is.
  2. Make a new release where we add a deprecation message when the mapped conversion is used. In this message, we would say something like: "... is deprecated and will be removed in the next version. Use ... instead." (see the sketch after this list).
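
A minimal sketch of what such a deprecation shim could look like (this wrapper is hypothetical and is not the change that went into #748):

import warnings

from aihwkit.nn.conversion import convert_to_analog

def convert_to_analog_mapped(module, rpu_config):
    """Hypothetical deprecation shim, not the actual aihwkit code."""
    warnings.warn(
        "convert_to_analog_mapped is deprecated and will be removed in the "
        "next version. Use convert_to_analog instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return convert_to_analog(module, rpu_config)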

anu-pub (Collaborator) commented Nov 14, 2025

@jubueche Include this fix and the deprecation notice in the upcoming release #748.

Signed-off-by: Julian Buechel <jub@zurich.ibm.com>
jubueche (Collaborator, Author) commented
Will be implemented in #748

jubueche closed this Nov 14, 2025