define data attribution for AnalogContext #717
base: master
Changes from all commits: 961356f, d6af456, f07b1fc, 4b0a202, 508c5c2, 19769ee, ddb0779, 668d42d, cab1fd4, ed9ef8d
```diff
@@ -150,7 +150,7 @@ def set_weights(
             if not isinstance(bias, Tensor):
                 bias = from_numpy(array(bias))

-            self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device)
+            self.bias.data.copy_(bias)
```
Contributor:

Does this allow for setting the bias with the data type defined in the tile? While it is correct that torch defines the data type of a layer solely by the data type of the weight tensor, I found it more convenient to handle all the specialized code via a `dtype` property at the tile level (as this is essentially the "analog tensor"). Do you suggest that the `dtype` should be removed from the tile and instead be determined by the dtype of the `ctx.data` tensor?
```diff
             bias = None

         combined_weights = self._combine_weights(weight, bias)
```
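As background for the dtype question above, here is a minimal plain-PyTorch sketch (not part of the PR) of what `Tensor.copy_` does: the destination tensor keeps its own dtype and device, so the explicit `.to(self.get_dtype()).to(self.bias.device)` chain of the removed line happens implicitly.

```python
import torch

# Destination plays the role of self.bias.data: it already has the tile's dtype.
dst = torch.zeros(4, dtype=torch.float16)

# Source plays the role of the user-supplied bias (a different dtype is allowed).
src = torch.arange(4, dtype=torch.float64)

# copy_ casts the source to the destination's dtype (and device) in place,
# so the tile keeps whatever dtype it was created with.
dst.copy_(src)
print(dst.dtype)  # torch.float16

# The removed line spelled the same conversion out explicitly:
# dst[:] = src[:].clone().detach().to(dst.dtype).to(dst.device)
```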
In a second file:

```diff
@@ -77,7 +77,7 @@ def set_weights(self, weight: Tensor) -> None:
         Args:
             weight: ``[out_size, in_size]`` weight matrix.
         """
-        self.weight.data = weight.clone().to(self.weight.device)
+        self.weight.data.copy_(weight)

     def get_weights(self) -> Tensor:
         """Get the tile weights.

@@ -87,7 +87,7 @@ def get_weights(self) -> Tensor:
             matrix; and the second item is either the ``[out_size]`` bias vector
             or ``None`` if the tile is set not to use bias.
         """
-        return self.weight.data.detach().cpu()
+        return self.weight.data
```
Contributor:

Note that the current convention is that `get_weights` always returns CPU weights. If you want to change this, the RPUCuda `get_weights` calls need to change as well, as they produce CPU weights by default. Moreover, `get_weights` will always produce a copy (without the backward trace) by design, to avoid implicit operations that cannot be done with analog weights. Of course, hardware-aware training is a special case, but for that we have a separate tile.
```diff

     def get_x_size(self) -> int:
         """Returns input size of tile"""
```
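To make the copy-versus-live-tensor distinction from the comment above concrete, here is a hedged sketch with a toy class (`ToyTile` is illustrative only, not the aihwkit API; the old-convention variant adds an explicit `.clone()` so that it also copies when the tile already lives on CPU):

```python
import torch

class ToyTile:
    """Stand-in for an analog tile holding a weight tensor (illustrative only)."""

    def __init__(self) -> None:
        self.weight = torch.nn.Parameter(torch.randn(2, 3))

    def get_weights_copy(self) -> torch.Tensor:
        # Old convention: always hand out a detached CPU copy.
        # (For a CUDA tile, .cpu() alone already copies to host memory.)
        return self.weight.data.detach().cpu().clone()

    def get_weights_live(self) -> torch.Tensor:
        # Behaviour after this PR: the live tensor itself (shares storage with the tile).
        return self.weight.data


tile = ToyTile()

w_copy = tile.get_weights_copy()
w_copy.zero_()                                   # does NOT touch the tile's weights
print(bool(tile.weight.data.abs().sum() > 0))    # True

w_live = tile.get_weights_live()
w_live.zero_()                                   # DOES touch the tile's weights
print(bool(tile.weight.data.abs().sum() == 0))   # True
```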
What is the logic here? Note that this will only be a copy of the current weights. So if you update the weights (using RPUCuda), the `analog_ctx.data` will not be synchronized correctly with the actual weight. Of course the size of the weight will not change, but it will be more confusing if one maintains two different versions of the weight which are not synced, or not?

Absolutely, there could be an out-of-sync concern here. Therefore, I also changed the definition of `self.tile.get_weights()`. Now the tile returns the original weight tensor instead of a detached copy. Since the `data` here and the actual weight tensor are the same object in essence, there is no sync issue.
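A short sketch of the synchronization argument in this exchange, using plain tensors; `weight` and `ctx_data` are illustrative stand-ins, not real aihwkit attributes:

```python
import torch

weight = torch.randn(2, 3)   # plays the role of the tile's weight tensor
ctx_data = weight            # plays the role of analog_ctx.data after this PR:
                             # the very same object, not a detached copy

# An in-place update of the tile weight (as an RPUCuda update would perform) ...
weight.add_(1.0)

# ... is immediately visible through the context; nothing has to be re-synced.
print(ctx_data.data_ptr() == weight.data_ptr())  # True
print(torch.equal(ctx_data, weight))             # True

# With the old detached-copy convention the two would drift apart:
stale_copy = weight.detach().cpu().clone()
weight.add_(1.0)
print(torch.equal(stale_copy, weight))           # False -- out of sync
```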