Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/jax wrapper autogalaxy fixes #156

Open
wants to merge 1 commit into
base: feature/jax_wrapper
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 91 additions & 91 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -1,92 +1,92 @@
name: Tests

on: [push, pull_request]

jobs:
unittest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, '3.10', '3.11', '3.12']
steps:
- name: Checkout PyAutoConf
uses: actions/checkout@v2
with:
repository: rhayes777/PyAutoConf
path: PyAutoConf
- name: Checkout PyAutoArray
uses: actions/checkout@v2
with:
path: PyAutoArray
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
id: cache-pip
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
# if: steps.cache-pip.outputs.cache-hit != 'true'
run: |
pip3 install --upgrade pip
pip3 install setuptools
pip3 install wheel
pip3 install pytest coverage pytest-cov
pip3 install -r PyAutoConf/requirements.txt
pip3 install -r PyAutoArray/requirements.txt
pip3 install -r PyAutoArray/optional_requirements.txt

cd PyAutoArray/autoarray/util/nn/src/nn
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/runner/work/PyAutoArray/PyAutoArray/PyAutoArray/autoarray/util/nn/src/nn
bash ./configure
cp makefile_autolens makefile
make
cd /home/runner/work/PyAutoArray/PyAutoArray

- name: Extract branch name
shell: bash
run: |
cd PyAutoArray
echo "##[set-output name=branch;]$(echo ${GITHUB_REF#refs/heads/})"
id: extract_branch
- name: Change to same branch if exists in deps
shell: bash
run: |
export PACKAGES=("PyAutoConf")
export BRANCH="${{ steps.extract_branch.outputs.branch }}"
for PACKAGE in ${PACKAGES[@]}; do
pushd $PACKAGE
export existed_in_remote=$(git ls-remote --heads origin ${BRANCH})
if [[ -z ${existed_in_remote} ]]; then
echo "Branch $BRANCH did not exist in $PACKAGE"
else
echo "Branch $BRANCH did exist in $PACKAGE"
git fetch
git checkout $BRANCH
fi
popd
done
- name: Run tests
run: |
export ROOT_DIR=`pwd`
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoConf
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoArray
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/runner/work/PyAutoArray/PyAutoArray/PyAutoArray/autoarray/util/nn/src/nn
pushd PyAutoArray
python3 -m pytest --cov autoarray --cov-report xml:coverage.xml
- name: Slack send
if: ${{ failure() }}
id: slack
uses: slackapi/slack-github-action@v1.21.0
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
with:
channel-id: C03S98FEDK2
payload: |
{
"text": "${{ github.repository }}/${{ github.ref_name }} (Python ${{ matrix.python-version }}) build result: ${{ job.status }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
name: Tests
on: [push, pull_request]
jobs:
unittest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, '3.10', '3.11', '3.12']
steps:
- name: Checkout PyAutoConf
uses: actions/checkout@v2
with:
repository: rhayes777/PyAutoConf
path: PyAutoConf
- name: Checkout PyAutoArray
uses: actions/checkout@v2
with:
path: PyAutoArray
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
id: cache-pip
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
# if: steps.cache-pip.outputs.cache-hit != 'true'
run: |
pip3 install --upgrade pip
pip3 install setuptools
pip3 install wheel
pip3 install pytest coverage pytest-cov
pip3 install -r PyAutoConf/requirements.txt
pip3 install -r PyAutoArray/requirements.txt
pip3 install -r PyAutoArray/optional_requirements.txt
cd PyAutoArray/autoarray/util/nn/src/nn
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/runner/work/PyAutoArray/PyAutoArray/PyAutoArray/autoarray/util/nn/src/nn
bash ./configure
cp makefile_autolens makefile
make
cd /home/runner/work/PyAutoArray/PyAutoArray
- name: Extract branch name
shell: bash
run: |
cd PyAutoArray
echo "##[set-output name=branch;]$(echo ${GITHUB_REF#refs/heads/})"
id: extract_branch
- name: Change to same branch if exists in deps
shell: bash
run: |
export PACKAGES=("PyAutoConf")
export BRANCH="${{ steps.extract_branch.outputs.branch }}"
for PACKAGE in ${PACKAGES[@]}; do
pushd $PACKAGE
export existed_in_remote=$(git ls-remote --heads origin ${BRANCH})
if [[ -z ${existed_in_remote} ]]; then
echo "Branch $BRANCH did not exist in $PACKAGE"
else
echo "Branch $BRANCH did exist in $PACKAGE"
git fetch
git checkout $BRANCH
fi
popd
done
- name: Run tests
run: |
export ROOT_DIR=`pwd`
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoConf
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoArray
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/runner/work/PyAutoArray/PyAutoArray/PyAutoArray/autoarray/util/nn/src/nn
pushd PyAutoArray
python3 -m pytest --cov autoarray --cov-report xml:coverage.xml
- name: Slack send
if: ${{ failure() }}
id: slack
uses: slackapi/slack-github-action@v1.21.0
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
with:
channel-id: C03S98FEDK2
payload: |
{
"text": "${{ github.repository }}/${{ github.ref_name }} (Python ${{ matrix.python-version }}) build result: ${{ job.status }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
}
36 changes: 18 additions & 18 deletions autoarray/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
fits:
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
inversion:
check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same.
use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion.
no_regularization_add_to_curvature_diag_value : 1.0e-3 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents inversion's failing due to the matrix being singular.
positive_only_uses_p_initial: true # If True, the positive-only solver of an inversion's uses an initial guess of the reconstructed data's values as which values should be positive, speeding up the solver.
use_border_relocator: false # If True, by default a pixelization's border is used to relocate all pixels outside its border to the border.
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
numba:
use_numba: true
cache: false
nopython: true
parallel: false
pixelization:
voronoi_nn_max_interpolation_neighbors: 300
structures:
native_binned_only: false # If True, data structures are only stored in their native and binned format. This is used to reduce memory usage in autocti.
fits:
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
inversion:
check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same.
use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion.
no_regularization_add_to_curvature_diag_value : 1.0e-3 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents inversion's failing due to the matrix being singular.
positive_only_uses_p_initial: true # If True, the positive-only solver of an inversion's uses an initial guess of the reconstructed data's values as which values should be positive, speeding up the solver.
use_border_relocator: false # If True, by default a pixelization's border is used to relocate all pixels outside its border to the border.
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
numba:
use_numba: true
cache: false
nopython: true
parallel: false
pixelization:
voronoi_nn_max_interpolation_neighbors: 300
structures:
native_binned_only: false # If True, data structures are only stored in their native and binned format. This is used to reduce memory usage in autocti.
6 changes: 3 additions & 3 deletions autoarray/config/grids.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
radial_minimum:
function_name:
class_name: 1.0e-08
radial_minimum:
function_name:
class_name: 1.0e-08
44 changes: 22 additions & 22 deletions autoarray/config/logging.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
version: 1
disable_existing_loggers: false

handlers:
console:
class: logging.StreamHandler
level: INFO
stream: ext://sys.stdout
formatter: formatter
file:
class: logging.FileHandler
level: INFO
filename: root.log
formatter: formatter

root:
level: INFO
handlers: [ console, file ]

formatters:
formatter:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
version: 1
disable_existing_loggers: false
handlers:
console:
class: logging.StreamHandler
level: INFO
stream: ext://sys.stdout
formatter: formatter
file:
class: logging.FileHandler
level: INFO
filename: root.log
formatter: formatter
root:
level: INFO
handlers: [ console, file ]
formatters:
formatter:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
74 changes: 37 additions & 37 deletions autoarray/numpy_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
import logging

from os import environ

use_jax = environ.get("USE_JAX", "0") == "1"

if use_jax:
try:
import jax
from jax import numpy as np, jit

print("JAX mode enabled")
except ImportError:
raise ImportError(
"JAX is not installed. Please install it with `pip install jax`."
)
else:
import numpy as np

def jit(function, *_, **__):
return function


try:
from jax._src.tree_util import register_pytree_node
from jax._src.tree_util import register_pytree_node_class

from jax import Array
except ImportError:

def register_pytree_node_class(cls):
return cls

def register_pytree_node(*_, **__):
pass

Array = np.ndarray
import logging
from os import environ
use_jax = environ.get("USE_JAX", "0") == "1"
if use_jax:
try:
import jax
from jax import numpy as np, jit
print("JAX mode enabled")
except ImportError:
raise ImportError(
"JAX is not installed. Please install it with `pip install jax`."
)
else:
import numpy as np
def jit(function, *_, **__):
return function
try:
from jax._src.tree_util import register_pytree_node
from jax._src.tree_util import register_pytree_node_class
from jax import Array
except ImportError:
def register_pytree_node_class(cls):
return cls
def register_pytree_node(*_, **__):
pass
Array = np.ndarray
4 changes: 2 additions & 2 deletions autoarray/structures/arrays/kernel_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
)

if normalize:
self._array[:] = np.divide(self._array, np.sum(self._array))
self._array = np.divide(self._array, np.sum(self._array))

@classmethod
def no_mask(
Expand Down Expand Up @@ -84,7 +84,7 @@ def no_mask(
If True, the Kernel2D's array values are normalized such that they sum to 1.0.
"""
values = Array2D.no_mask(
values=values,
values=values.array,
shape_native=shape_native,
pixel_scales=pixel_scales,
origin=origin,
Expand Down
Loading
Loading