diff --git a/CHANGELOG.md b/CHANGELOG.md index 955b8db..fe0f0cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Changelog ## 1.0.0 +* Public release -* Public release \ No newline at end of file +## 1.0.1 (Patches for Pytorch 2.0.0) +* added `grad.setter` to `PseudoParameterModule` class diff --git a/README.md b/README.md index 919f2b9..1fedc3f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ pip install analogvnn ![3 Layered Linear Photonic Analog Neural Network](docs/_static/analogvnn_model.png) -[//]: # (![3 Layered Linear Photonic Analog Neural Network](https://github.com/Vivswan/AnalogVNN/raw/v1.0.0/docs/_static/analogvnn_model.png)) +[//]: # (![3 Layered Linear Photonic Analog Neural Network](https://github.com/Vivswan/AnalogVNN/raw/release/docs/_static/analogvnn_model.png)) ## Abstract diff --git a/analogvnn/parameter/PseudoParameter.py b/analogvnn/parameter/PseudoParameter.py index 2118e6f..1574416 100644 --- a/analogvnn/parameter/PseudoParameter.py +++ b/analogvnn/parameter/PseudoParameter.py @@ -178,6 +178,16 @@ def grad(self): return self._transformed.grad + @grad.setter + def grad(self, grad: Tensor): + """Sets the gradient of the parameter. + + Args: + grad (Tensor): the gradient. + """ + + self._transformed.grad = grad + @property def module(self): """Returns the module. diff --git a/pyproject.toml b/pyproject.toml index 9929739..08f55aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ py-modules = ['analogvnn'] [project] # $ pip install analogvnn name = "analogvnn" -version = "1.0.0" +version = "1.0.1" description = "A fully modular framework for modeling and optimizing analog/photonic neural networks" # Optional readme = "README.md" requires-python = ">=3.7" diff --git a/requirements.txt b/requirements.txt index e87479b..61ad414 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ ---extra-index-url https://download.pytorch.org/whl/cu116 +--extra-index-url https://download.pytorch.org/whl/cu118 torch -torchvision~=0.14.0 +torchvision torchaudio numpy scipy @@ -11,5 +11,6 @@ importlib-metadata; python_version < '3.8' tensorflow>=2.0.0 tensorboard>=2.0.0 torchinfo -# conda install graphviz +# conda install graphviz python-graphviz pydot pydotplus python-dotenv +# conda install --channel conda-forge pygraphviz graphviz