Merge pull request #49 from Vivswan/develop
v1.0.7
Vivswan authored Nov 22, 2023
2 parents 3ff13cc + 38a966c commit 0bf684c
Showing 9 changed files with 41 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
 # Changelog
 
+## 1.0.7
+* Fixed `GeLU` backward function equation.
+
 ## 1.0.6
 
 * `Model` is subclass of `BackwardModule` for additional functionality.
4 changes: 3 additions & 1 deletion README.md
@@ -37,7 +37,9 @@ Documentation: [https://analogvnn.readthedocs.io/](https://analogvnn.readthedocs
 
 ## Abstract
 
-![3 Layered Linear Photonic Analog Neural Network](https://github.com/Vivswan/AnalogVNN/raw/release/docs/_static/analogvnn_model.png)
+![3 Layered Linear Photonic Analog Neural Network](docs/_static/analogvnn_model.png)
 
+[//]: # (![3 Layered Linear Photonic Analog Neural Network](https://github.com/Vivswan/AnalogVNN/raw/release/docs/_static/analogvnn_model.png))
+
 **AnalogVNN** is a simulation framework built on PyTorch which can simulate the effects of
 optoelectronic noise, limited precision, and signal normalization present in photonic
14 changes: 9 additions & 5 deletions analogvnn/graph/AccumulateGrad.py
@@ -1,4 +1,4 @@
-from typing import Dict, Union, Callable
+from typing import Dict, Union, Callable, List
 
 import torch
 from torch import nn
@@ -9,6 +9,10 @@
 __all__ = ['AccumulateGrad']
 
 
+def _get_index(tensor: torch.Tensor, tensor_list: List[torch.Tensor]) -> int:
+    return [(i.shape == tensor.shape and torch.all(torch.eq(i, tensor))) for i in tensor_list].index(True)
+
+
 class AccumulateGrad:
     """AccumulateGrad is a module that accumulates the gradients of the outputs of the module it is attached to.
@@ -74,14 +78,14 @@ def __call__( # noqa: C901
             if forward_out_arg is True and isinstance(forward_in_arg, int) and not isinstance(forward_in_arg, bool):
                 forward_inputs = forward_input_output_graph[predecessor].inputs.args
                 forward_outputs = forward_input_output_graph[self.module].outputs.args
-                forward_out_arg = forward_inputs.index(forward_outputs[forward_in_arg])
+                forward_out_arg = _get_index(forward_outputs[forward_in_arg], forward_inputs)
                 grad_output = grad_output[forward_out_arg]
 
             # 7
             if forward_out_arg is True and isinstance(forward_in_kwarg, str):
                 forward_inputs = forward_input_output_graph[predecessor].inputs.args
                 forward_outputs = forward_input_output_graph[self.module].outputs.kwargs
-                forward_out_arg = forward_inputs.index(forward_outputs[forward_in_kwarg])
+                forward_out_arg = _get_index(forward_outputs[forward_in_kwarg], forward_inputs)
                 grad_output = grad_output[forward_out_arg]
 
             # 1
@@ -92,7 +96,7 @@ def __call__( # noqa: C901
                 if forward_inputs[i] not in forward_outputs:
                     continue
 
-                value_index = forward_outputs.index(forward_inputs[i])
+                value_index = _get_index(forward_inputs[i], forward_outputs)
                 if value_index not in grad_inputs_args:
                     grad_inputs_args[value_index] = torch.zeros_like(grad_output[i])
                 grad_inputs_args[value_index] += grad_output[i]
@@ -103,7 +107,7 @@ def __call__( # noqa: C901
             forward_inputs = forward_input_output_graph[predecessor].inputs.args
             forward_outputs = forward_input_output_graph[self.module].outputs.kwargs
             for i in forward_outputs:
-                value_index = forward_inputs.index(forward_outputs[i])
+                value_index = _get_index(forward_outputs[i], forward_inputs)
 
                 if i not in grad_inputs_kwargs:
                     grad_inputs_kwargs[i] = torch.zeros_like(grad_output[value_index])
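The `_get_index` helper introduced above replaces the `list.index` calls because `list.index` compares candidates with `==`, and tensor `==` is element-wise, so truth-testing a multi-element result raises instead of answering the lookup. A minimal sketch of the failure and the fix (the tensors are illustrative; the helper is copied from the hunk above):

```python
import torch
from typing import List


def _get_index(tensor: torch.Tensor, tensor_list: List[torch.Tensor]) -> int:
    # Compare shape first, then reduce element-wise equality with torch.all,
    # so every candidate yields a single boolean-like value.
    return [
        (i.shape == tensor.shape and torch.all(torch.eq(i, tensor)))
        for i in tensor_list
    ].index(True)


a, b = torch.zeros(2, 3), torch.ones(2, 3)
# [a, b].index(b) raises RuntimeError: "a == b" is an element-wise boolean
# tensor, and its truth value is ambiguous.
print(_get_index(b, [a, b]))  # -> 1
```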
17 changes: 15 additions & 2 deletions analogvnn/graph/ForwardGraph.py
@@ -73,8 +73,15 @@ def calculate(
             inputs = (inputs,)
 
         if not self.graph_state.use_autograd_graph and is_training:
+            value_tensor = False
             for i in inputs:
+                if not isinstance(i, torch.Tensor):
+                    continue
                 i.requires_grad = True
+                value_tensor = True
+
+            if not value_tensor:
+                raise ValueError('At least one input must be a tensor.')
 
         input_output_graph = self._pass(self.INPUT, *inputs)
         if is_training:
@@ -102,8 +109,14 @@ def _pass(self, from_node: GraphEnum, *inputs: Tensor) -> Dict[GraphEnum, InputO
             if module != from_node:
                 inputs = self.parse_args_kwargs(input_output_graph, module, predecessors)
                 if not self.graph_state.use_autograd_graph:
-                    inputs.args = [self._detach_tensor(i) for i in inputs.args]
-                    inputs.kwargs = {k: self._detach_tensor(v) for k, v in inputs.kwargs.items()}
+                    inputs.args = [
+                        self._detach_tensor(i) if isinstance(i, torch.Tensor) else i
+                        for i in inputs.args
+                    ]
+                    inputs.kwargs = {
+                        k: self._detach_tensor(v) if isinstance(v, torch.Tensor) else v
+                        for k, v in inputs.kwargs.items()
+                    }
                 input_output_graph[module] = InputOutput(inputs=inputs)
 
             if isinstance(module, GraphEnum):
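Both hunks above add the same kind of guard: `calculate` now only flips `requires_grad` on actual tensors and insists that at least one input is a tensor, while `_pass` detaches only tensor values and passes everything else through. A standalone sketch of the `calculate` guard (the function name and sample inputs are hypothetical, not AnalogVNN API):

```python
import torch


def mark_tensor_inputs(*inputs):
    # Mark every tensor input as requiring grad; non-tensor extras such as
    # labels or strings pass through untouched.
    value_tensor = False
    for i in inputs:
        if not isinstance(i, torch.Tensor):
            continue
        i.requires_grad = True
        value_tensor = True

    # Fail loudly when nothing can carry a gradient.
    if not value_tensor:
        raise ValueError('At least one input must be a tensor.')
    return inputs


mark_tensor_inputs(torch.zeros(4), 'metadata')  # ok
# mark_tensor_inputs('metadata', 42)            # raises ValueError
```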
2 changes: 1 addition & 1 deletion analogvnn/nn/activation/Gaussian.py
@@ -68,6 +68,6 @@ def backward(self, grad_output: Optional[Tensor]) -> Optional[Tensor]:
 
         x = self.inputs
         grad = (1 / 2) * (
-            (1 + torch.erf(x / math.sqrt(2))) + x * ((2 / math.sqrt(math.pi)) * torch.exp(-torch.pow(x, 2)))
+            (1 + torch.erf(x / math.sqrt(2))) + x * (math.sqrt(2 * math.pi) * torch.exp(-torch.pow(x, 2) / 2))
         )
         return grad_output * grad
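The changelog entry above calls this a fix to the `GeLU` backward equation. For comparison, the textbook derivative of the exact (erf-based) GeLU, with Φ and φ the standard normal CDF and PDF, is:

```latex
% GELU(x) = x * Phi(x)
\frac{d}{dx}\,\mathrm{GELU}(x)
  = \Phi(x) + x\,\varphi(x)
  = \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
    + \frac{x}{\sqrt{2\pi}}\, e^{-x^{2}/2}
```

Reading the hunk against this form shows what changed: the Gaussian exponent moves from -x² to -x²/2, matching φ(x) up to its leading constant.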
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
 # pyproject-build
 
 [build-system]
-requires = ["wheel", "setuptools>=61.0.0", "flit_core >=3.2,<4"]
+requires = ["wheel>=0.38.0", "setuptools>=61.0.0", "flit_core>=3.2,<4"]
 build-backend = "flit_core.buildapi"
 
 [tool.flit.module]
@@ -19,7 +19,7 @@ where = ["analogvnn"]
 [project]
 # $ pip install analogvnn
 name = "analogvnn"
-version = "1.0.6"
+version = "1.0.7"
 description = "A fully modular framework for modeling and optimizing analog/photonic neural networks"
 readme = "README.md"
 requires-python = ">=3.7"
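Beyond the `wheel` pin, this file bumps the released version to 1.0.7. Assuming the usual flit/PyPA tooling (these commands are not part of the diff, and publication to PyPI is assumed), the build and install flow is:

```shell
python -m build                 # sdist + wheel via flit_core.buildapi
pip install analogvnn==1.0.7    # assuming the release is on PyPI
```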
7 changes: 5 additions & 2 deletions requirements.txt
@@ -1,16 +1,19 @@
---extra-index-url https://download.pytorch.org/whl/cu118
+--extra-index-url https://download.pytorch.org/whl/cu121
 torch
 torchvision
 torchaudio
-numpy
+numpy>=1.22.2
 scipy
 networkx
 importlib-metadata; python_version < '3.8'
 
 # Full
+wheel>=0.38.0
 tensorflow>=2.0.0
 tensorboard>=2.0.0
 torchinfo
 # conda install graphviz python-graphviz pydot pydotplus python-dotenv
 # conda install --channel conda-forge pygraphviz
 graphviz
+pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
+werkzeug>=3.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
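With the extra index switched to the CUDA 12.1 channel, pip resolves `torch`, `torchvision`, and `torchaudio` from PyTorch's wheel index and everything else from PyPI. Typical usage is unchanged (a CUDA 12.1-capable driver is assumed for GPU support, which the diff does not state):

```shell
pip install -r requirements.txt
# torch / torchvision / torchaudio come from https://download.pytorch.org/whl/cu121
```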
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
@@ -1,6 +1,6 @@
 # Development
 flit
-setuptools>=61.0.0
+setuptools>=65.5.1
 build # building the package {pyproject-build}
 twine # to publish on pypi {twine upload --repository-url=https://test.pypi.org/legacy/ dist/*} {twine upload dist/*}
 johnnydep # to see dependencies {johnnydep <package>}
2 changes: 2 additions & 0 deletions requirements/requirements-docs.txt
@@ -11,3 +11,5 @@ sphinx-notfound-page
 sphinx-inline-tabs
 sphinxext-opengraph
 sphinxcontrib-katex # math
+tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability
+pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
