@@ -0,0 +1,55 @@
from torch import nn
import numpy as np
import torch.nn.functional as F

class CNN(nn.Module):
"""
    A CNN that repeatedly halves the spatial resolution of the input and flattens the result into a feature vector

    Parameters
    ----------
    image_x: int
        Width of the images to be processed (dim 3 of a (B, C, H, W) batch)
    image_y: int
        Height of the images to be processed (dim 2 of a (B, C, H, W) batch)
    image_c: int
        Number of channels of the images to be processed (dim 1 of a (B, C, H, W) batch)
    out_shape: int
        Size of the output feature vector

    Methods
    -------
    forward(x)
        Feed-forward method
"""
    def __init__(self, image_x, image_y, image_c, out_shape):
        super().__init__()
        # Number of times the larger spatial dimension can be halved
        pow_2 = int(np.floor(np.log2(max(image_x, image_y))))
        self.module_list = nn.ModuleList()
        channels = image_c
        image_x_, image_y_ = image_x, image_y
        # Each block (strided conv + strided max-pool) downsamples by a factor of 4 and doubles the channel count
        for _ in range(0, pow_2, 2):
self.module_list.append(
nn.Conv2d(in_channels=channels, out_channels=channels*2, kernel_size=4, stride=2, padding=1)
)
self.module_list.append(
nn.ReLU()
)
self.module_list.append(
nn.MaxPool2d(kernel_size=4, stride=2, padding=1)
)
channels *= 2
image_x_ //= 4
image_y_ //= 4
self.fc1 = nn.Linear(in_features=channels*image_x_*image_y_, out_features=128)
self.fc2 = nn.Linear(in_features=128, out_features=out_shape)

    def forward(self, x):
        # Convolutional feature extraction
        for module in self.module_list:
            x = module(x)
        # Flatten to (B, features) and map through the fully connected head
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
return x
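

# A minimal usage sketch, not part of the module above: the shapes assume the
# 256x256, single-channel, 6-parameter configuration this encoder is used with
# elsewhere in this PR; the batch size of 4 is arbitrary.
if __name__ == '__main__':
    import torch

    model = CNN(image_x=256, image_y=256, image_c=1, out_shape=6)
    dummy = torch.randn(4, 1, 256, 256)  # (B, C, H, W)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([4, 6])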

@@ -0,0 +1,165 @@
import torch
import numpy as np
from tqdm import tqdm
import argparse
import os
import random
from torch.utils import tensorboard
from lensing_envs.lensing_envs import Source
from cnn import CNN

class SersicDecoder(torch.nn.Module):
"""
    A wrapper around the CNN encoder with two heads: one regresses Sérsic parameters and the other decides whether to stop adding further Sérsics

    Parameters
    ----------
    image_shape: tuple of int
        (image_x, image_y, image_c) shape of the square images to be studied
    hidden_dim: int
        Size of the hidden dimension before passing through to the heads
    output_dim: int
        Output dimension of the Sérsic-parameter head

    Methods
    -------
    forward(x)
        Feed-forward method
"""
def __init__(self, image_shape, hidden_dim, output_dim=6): # 6 = (n, r_e, q, θ, x₀, y₀)
super().__init__()
self.encoder = CNN(image_shape[0], image_shape[1], image_shape[2], hidden_dim)
self.fc_sersic = torch.nn.Linear(hidden_dim, output_dim)
self.fc_stop = torch.nn.Linear(hidden_dim, 1) # For termination probability

def forward(self, x): # x: (B, C, H, W)
z = self.encoder(x) # (B, hidden_dim)
sersic_params = self.fc_sersic(z) # (B, output_dim)
stop_prob = torch.sigmoid(self.fc_stop(z)).squeeze(-1) # (B,)
return sersic_params, stop_prob
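

# A minimal, hypothetical smoke test for the decoder above, included purely as an
# illustration and never called (the training loop below uses CNN directly); the
# image_shape follows the (image_x, image_y, image_c) convention and the batch size
# of 2 is arbitrary.
def _sersic_decoder_example():
    decoder = SersicDecoder(image_shape=(256, 256, 1), hidden_dim=128)
    params, stop = decoder(torch.randn(2, 1, 256, 256))
    # Two heads: (B, 6) Sérsic parameters and (B,) termination probabilities
    assert params.shape == (2, 6) and stop.shape == (2,)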


def strtobool(x):
    """
    Helper that returns True if the given string is 'true' (case-insensitive), and False otherwise
    """
    return x.lower().strip() == 'true'

def parse_args():
"""
Handles arguments for the argparse
"""
parser = argparse.ArgumentParser()
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help='the name of this experiment')
parser.add_argument('--learning-rate', type=float, default=2.5e-4,
help='the LR of the optimizer(s)')
parser.add_argument('--seed', type=int, default=0,
help='the seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=8000,
        help='total timesteps of the experiment')
parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help='if False, `torch.backends.cudnn.deterministic=False`')
parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
help='if True, cuda will be enabled when possible')
parser.add_argument('--log-train', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
help='if True, training will be logged with Tensorboard')

# Performance altering
parser.add_argument('--num-steps', type=int, default=10,
help='number of steps per environment per rollout')
args = parser.parse_args()
return args
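

# Example invocation (the filename train.py is only an assumption; --exp-name defaults
# to whatever this file is actually called):
#
#   python train.py --total-timesteps 8000 --cuda False --log-train True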

selected_galaxies = np.load('selected_galaxies.npy')
# Collapse the colour channels to a single greyscale channel
selected_galaxies = np.mean(selected_galaxies, axis=-1, keepdims=True)
B, y, x, c = selected_galaxies.shape  # (batch, height, width, channels)
# Reorder to the (B, C, H, W) layout expected by torch
selected_galaxies = np.transpose(selected_galaxies, (0, 3, 1, 2))
# Min-max normalise each image to [0, 1]
selected_galaxies_min = selected_galaxies.min(axis=(-1, -2), keepdims=True)
selected_galaxies_max = selected_galaxies.max(axis=(-1, -2), keepdims=True)
selected_galaxies = (selected_galaxies - selected_galaxies_min) / (selected_galaxies_max - selected_galaxies_min)

if __name__ == '__main__':
args = parse_args()

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
print(f'[AGENT] Seed set to {args.seed}')

    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
if args.log_train:
writer = tensorboard.SummaryWriter(f'runs/{args.exp_name}')
writer.add_text(
'hyperparameters',
            '|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
)
def make_env(seed):
def thunk():
env = Source(hyperparameters={
'B':47,
'image_x':256,
'image_y':256,
'image_c':1,
'seed':seed,
'cuda':args.cuda,
})
return env
return thunk

env = make_env(seed=args.seed)()

model = CNN(env.image_x, env.image_y, env.image_c, env.low.shape[1]).to(device)
delta = 1e-6
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
episodic_loss = 0

temp = []
    with tqdm(range(int(args.total_timesteps)), desc=f'episodic_loss: {episodic_loss}') as progress:
best_actions = None
best_loss = np.inf

for i in range(int(args.total_timesteps)):
done = False
state, _ = env.reset(selected_galaxies)

params_list = []
sersics = torch.zeros((args.num_steps, env.B, env.image_c, env.image_y, env.image_x)).to(device)
j=0
while j < args.num_steps:
# gathering rollout data
action = model(state)
state, reward, done, _, info = env.step(action)
sersics[j] = info['source']
state = state.detach()
params_list.append(action.detach().cpu().numpy())
if torch.all(done):
break
j += 1
y_pred = torch.sum(sersics, dim=0)
y_pred_flat = y_pred.view(env.B, -1)
y_pred_min, _ = y_pred_flat.min(dim=-1, keepdim=True)
y_pred_max, _ = y_pred_flat.max(dim=-1, keepdim=True)
y_pred_min, y_pred_max = y_pred_min.view(env.B, 1, 1, 1), y_pred_max.view(env.B, 1, 1, 1)
y_pred = (y_pred - y_pred_min) / (y_pred_max - y_pred_min + delta)
y_labels = env.source_labels
            loss = torch.nn.functional.mse_loss(y_pred, y_labels)
opt.zero_grad()
loss.backward()
opt.step()

episodic_loss = loss.detach().cpu().numpy()
            # Keep the per-step parameter trajectories from late in training (after step 4200)
            if i > 4200:
                temp.append(params_list)
progress.set_description(f'episodic_loss: {episodic_loss}')
progress.update()
if args.log_train:
writer.add_scalar("loss/actor_loss", episodic_loss, global_step=i)

if episodic_loss < best_loss:
best_loss = episodic_loss
best_actions = np.array(params_list)
np.save(f'best_actions_{args.exp_name}', best_actions)
print('Best_loss',best_loss)
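                # The file written above can be reloaded later with
                # np.load(f'best_actions_{args.exp_name}.npy') (np.save appends the .npy suffix);
                # its shape is (steps taken in that episode, B, 6).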
@@ -0,0 +1,6 @@
from gymnasium.envs.registration import register

register(
id="Source-v0",
entry_point="lensing_envs.lensing_envs:Source",
)
@@ -0,0 +1,117 @@
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F

class Source():
"""
Analytic source construction module that creates compositions of Sérsics

Arguments
---------
    hyperparameters: dict
        Dictionary of environment hyperparameters (see _init_hyperparameters for the recognised keys and their defaults)

Methods
-------
reset(source_labels)
        Resets the environment with the given source labels and returns them as the initial observation
step(action, delta=1e-4)
Steps the environment with the given action (Sérsic parameters) to return the new set of observations, the reward, and the done value
_init_hyperparameters(hyperparameters)
Initialises hyperparameters in the environment
_process_actions(action)
Scales the actions to match the environment's bounds
_add_sersic(sersic_params)
Helper function that adds Sérsics to the current collection, called when stepping the environment
"""
def __init__(self, hyperparameters):
self._init_hyperparameters(hyperparameters)
        # r_e is the radius that encloses half of the total light of the galaxy, set here between 0 and 1 (scaled to image size)
        # Per-parameter bounds: n [0, 15], r_e [0, 1], q [0, 1], theta [0, 2pi], center_x [-1, 1], center_y [-1, 1]
        # (_process_actions additionally clips each parameter to at least low + 0.1, so e.g. n is effectively >= 0.1)
        self.low = torch.reshape(torch.tensor([0., 0., 0., 0., -1., -1.]), (1, -1)) * torch.ones((self.B, 1))
        self.high = torch.reshape(torch.tensor([15., 1., 1., 2*torch.pi, 1., 1.]), (1, -1)) * torch.ones((self.B, 1))
self.low, self.high = self.low.to(self.device), self.high.to(self.device)
self.image_arcsec_bounds = torch.tensor([self.image_x * self.arcsec_per_pixel / 2, self.image_y * self.arcsec_per_pixel / 2])
pos_x = torch.linspace(-self.image_arcsec_bounds[0], self.image_arcsec_bounds[0], self.image_x).to(self.device)
pos_y = torch.linspace(-self.image_arcsec_bounds[1], self.image_arcsec_bounds[1], self.image_y).to(self.device)
        # Pixel-centre coordinate grids in arcsec, broadcast to (B, C, H, W)
        phi_y, phi_x = torch.meshgrid(pos_y, pos_x, indexing='ij')
        ones = torch.ones((self.B, self.image_c, self.image_y, self.image_x), device=self.device)
        self.phi_y = torch.reshape(phi_y, (1, 1, self.image_y, self.image_x)) * ones
        self.phi_x = torch.reshape(phi_x, (1, 1, self.image_y, self.image_x)) * ones

def reset(self, source_labels):
self.source_labels = torch.tensor(source_labels, device=self.device, dtype=torch.float32)
self.constructed_sersic = torch.zeros(self.B, self.image_c, self.image_y, self.image_x).to(self.device)
return self.source_labels, self._get_info()

def step(self, action, delta=1e-4):
action = self._process_actions(action)
I = self._add_sersic(action)
self.constructed_sersic += I
source_diff = self.source_labels - self.constructed_sersic
reward = -torch.mean((source_diff)**2, dim=(-3, -2, -1))
reward_x = -torch.mean((torch.sum(source_diff, dim=-1)/self.image_x)**2, dim=(-2, -1))
reward_y = -torch.mean((torch.sum(source_diff, dim=-2)/self.image_y)**2, dim=(-2, -1))
reward += reward_x + reward_y
done = torch.mean(torch.abs(source_diff), dim=(-3, -2, -1)) < delta
if self.render_mode == 'source':
return source_diff, reward, done, False, {'source':I, 'sersics':self.constructed_sersic, 'labels':self.source_labels}
return source_diff, 100*reward, done, False, {}

def _init_hyperparameters(self, hyperparameters):
self.B = 1
self.image_x, self.image_y, self.image_c = 1, 1, 1
self.seed = 0
self.cuda = False
self.arcsec_per_pixel = 0.001
self.num_sersics = 10
self.render_mode = 'source'
        # Override the defaults above with any user-supplied values
        for param, val in hyperparameters.items():
            setattr(self, param, val)

        self.device = torch.device('cuda' if self.cuda and torch.cuda.is_available() else 'cpu')
        print(f'[ENV] Using {self.device}')

        if self.seed is not None:
            assert isinstance(self.seed, int)
            torch.manual_seed(self.seed)
            np.random.seed(self.seed)
            print(f"[ENV] Seed set to {self.seed}")

def _get_info(self):
return {}

    def _process_actions(self, action):
        # Raw actions are unbounded reals; squash them and rescale to the per-parameter bounds
        squashed_action = torch.tanh(action) # in [-1, 1]
        action = 0.5 * (squashed_action + 1) # in [0, 1]
        action = torch.clip(action * (self.high - self.low) + self.low, self.low + 1e-1, self.high)
        return action
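
    # Worked example (illustrative only): for the Sérsic index n, whose bounds above are
    # low = 0 and high = 15, a raw action of 0 maps to tanh(0) = 0 -> 0.5 -> 0.5 * 15 = 7.5,
    # while a very negative raw action maps towards 0 and is then clipped up to low + 0.1.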

def _add_sersic(self, sersic_params):
delta = 1e-6
def _sersic_law(R, r_e, b_n, n):
return torch.exp(-b_n * ((R / (r_e + delta)) ** (1/(n + delta)) - 1))
n, r_e_, q, theta, centers_x, centers_y = torch.tensor_split(sersic_params, 6, dim=-1) # (B, 1,)* on all
centers_x, centers_y = centers_x*self.image_arcsec_bounds[0], centers_y*self.image_arcsec_bounds[1] # (B, 1)*
        b_n = 2 * n - 0.331 # (B, 1); approximation to the exact Sérsic b_n, defined by Γ(2n) = 2γ(2n, b_n)

r_e = torch.linalg.norm(self.image_arcsec_bounds) * r_e_
r_e = r_e.view(self.B, 1, 1, 1)
q = q.view(self.B, 1, 1, 1)
n = n.view(self.B, 1, 1, 1)
theta = theta.view(self.B, 1, 1, 1)
b_n = b_n.view(self.B, 1, 1, 1) # (B, 1, 1, 1) on all
centers_x, centers_y = centers_x.view(self.B, 1, 1, 1), centers_y.view(self.B, 1, 1, 1)

        phi_y, phi_x = self.phi_y - centers_y, self.phi_x - centers_x # coordinates shifted to the Sérsic centre, (B, c, y, x)
        cos_theta, sin_theta = torch.cos(theta), torch.sin(theta) # (B, 1, 1, 1)
        x_rot, y_rot = phi_x * cos_theta + phi_y * sin_theta, -phi_x * sin_theta + phi_y * cos_theta # coordinates rotated by theta, (B, c, y, x)
        R = torch.sqrt((x_rot**2)/q + q*(y_rot**2)) * self.arcsec_per_pixel # elliptical radius, (B, c, y, x)
R[R==0] = 1e-6
I = _sersic_law(R, r_e, b_n, n) # (B, c, y, x,)
I_flat = I.view(self.B, -1)
I_min, _ = I_flat.min(dim=-1, keepdim=True)
I_max, _ = I_flat.max(dim=-1, keepdim=True)
I_min, I_max = I_min.view(self.B, 1, 1, 1), I_max.view(self.B, 1, 1, 1)
return (I - I_min) / (I_max - I_min + delta)
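
    # Worked check of the profile above (illustrative only): at R = r_e the exponent is
    # -b_n * ((r_e / r_e) ** (1/n) - 1) = 0 (ignoring the small delta), so the raw Sérsic
    # intensity equals 1 at the effective radius for any n, before the per-image min-max
    # rescaling applied on the final lines.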