Skip to content

Commit 465c625

Browse files
committed
b,ack
1 parent f2ad520 commit 465c625

File tree

8 files changed

+100
-66
lines changed

8 files changed

+100
-66
lines changed

autolens/analysis/result.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def image_plane_multiple_image_positions(self) -> aa.Grid2DIrregular:
109109

110110
return aa.Grid2DIrregular(values=multiple_images)
111111

112-
def image_plane_multiple_image_positions_for_single_image_from(self, increments : int = 20) -> aa.Grid2DIrregular:
112+
def image_plane_multiple_image_positions_for_single_image_from(
113+
self, increments: int = 20
114+
) -> aa.Grid2DIrregular:
113115
"""
114116
If the standard point solver only locates one multiple image, finds one or more additional images, which are
115117
not technically multiple image in the point source regime, but are close enough to it they can be used
@@ -133,12 +135,14 @@ def image_plane_multiple_image_positions_for_single_image_from(self, increments
133135
The number of increments the source-plane centre is moved to compute multiple images.
134136
"""
135137

136-
logger.info("""
138+
logger.info(
139+
"""
137140
Could not find multiple images for maximum likelihood lens model.
138141
139142
Incrementally moving source centre inwards towards centre of source-plane until caustic crossing occurs
140143
and multiple images are formed.
141-
""")
144+
"""
145+
)
142146

143147
grid = self.analysis.dataset.mask.derive_grid.all_false
144148

@@ -150,8 +154,7 @@ def image_plane_multiple_image_positions_for_single_image_from(self, increments
150154
)
151155

152156
for i in range(1, increments):
153-
154-
factor = 1.0 - (1.0 * (i/increments))
157+
factor = 1.0 - (1.0 * (i / increments))
155158

156159
multiple_images = solver.solve(
157160
tracer=self.max_log_likelihood_tracer,

autolens/lens/sensitivity.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@ def set_auto_filename(
361361
return False
362362

363363
def subplot_sensitivity(self):
364-
365364
log_likelihoods = self.result.figure_of_merit_array(
366365
use_log_evidences=False,
367366
remove_zeros=True,
@@ -375,7 +374,7 @@ def subplot_sensitivity(self):
375374
except TypeError:
376375
log_evidences = np.zeros_like(log_likelihoods)
377376

378-
self.open_subplot_figure(number_subplots=8, subplot_shape=(2,4))
377+
self.open_subplot_figure(number_subplots=8, subplot_shape=(2, 4))
379378

380379
plotter = aplt.Array2DPlotter(
381380
array=self.data_subtracted,
@@ -398,10 +397,7 @@ def subplot_sensitivity(self):
398397

399398
above_threshold = np.where(log_likelihoods > 5.0, 1.0, 0.0)
400399

401-
above_threshold = aa.Array2D(
402-
values=above_threshold,
403-
mask=log_likelihoods.mask
404-
)
400+
above_threshold = aa.Array2D(values=above_threshold, mask=log_likelihoods.mask)
405401

406402
self.mat_plot_2d.plot_array(
407403
array=above_threshold,
@@ -410,16 +406,32 @@ def subplot_sensitivity(self):
410406
)
411407

412408
try:
413-
log_evidences_base = self.result._array_2d_from(self.result.log_evidences_base)
414-
log_evidences_perturbed = self.result._array_2d_from(self.result.log_evidences_perturbed)
409+
log_evidences_base = self.result._array_2d_from(
410+
self.result.log_evidences_base
411+
)
412+
log_evidences_perturbed = self.result._array_2d_from(
413+
self.result.log_evidences_perturbed
414+
)
415415

416-
log_evidences_base_min = np.nanmin(np.where(log_evidences_base == 0, np.nan, log_evidences_base))
417-
log_evidences_base_max = np.nanmax(np.where(log_evidences_base == 0, np.nan, log_evidences_base))
418-
log_evidences_perturbed_min = np.nanmin(np.where(log_evidences_perturbed == 0, np.nan, log_evidences_perturbed))
419-
log_evidences_perturbed_max = np.nanmax(np.where(log_evidences_perturbed == 0, np.nan, log_evidences_perturbed))
416+
log_evidences_base_min = np.nanmin(
417+
np.where(log_evidences_base == 0, np.nan, log_evidences_base)
418+
)
419+
log_evidences_base_max = np.nanmax(
420+
np.where(log_evidences_base == 0, np.nan, log_evidences_base)
421+
)
422+
log_evidences_perturbed_min = np.nanmin(
423+
np.where(log_evidences_perturbed == 0, np.nan, log_evidences_perturbed)
424+
)
425+
log_evidences_perturbed_max = np.nanmax(
426+
np.where(log_evidences_perturbed == 0, np.nan, log_evidences_perturbed)
427+
)
420428

421-
self.mat_plot_2d.cmap.kwargs["vmin"] = np.min([log_evidences_base_min, log_evidences_perturbed_min])
422-
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max([log_evidences_base_max, log_evidences_perturbed_max])
429+
self.mat_plot_2d.cmap.kwargs["vmin"] = np.min(
430+
[log_evidences_base_min, log_evidences_perturbed_min]
431+
)
432+
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(
433+
[log_evidences_base_max, log_evidences_perturbed_max]
434+
)
423435

424436
self.mat_plot_2d.plot_array(
425437
array=log_evidences_base,
@@ -431,21 +443,36 @@ def subplot_sensitivity(self):
431443
array=log_evidences_perturbed,
432444
visuals_2d=self.visuals_2d,
433445
auto_labels=AutoLabels(title="Log Evidence Perturb"),
434-
435446
)
436447
except TypeError:
437448
pass
438-
439-
log_likelihoods_base = self.result._array_2d_from(self.result.log_likelihoods_base)
440-
log_likelihoods_perturbed = self.result._array_2d_from(self.result.log_likelihoods_perturbed)
441449

442-
log_likelihoods_base_min = np.nanmin(np.where(log_likelihoods_base == 0, np.nan, log_likelihoods_base))
443-
log_likelihoods_base_max = np.nanmax(np.where(log_likelihoods_base == 0, np.nan, log_likelihoods_base))
444-
log_likelihoods_perturbed_min = np.nanmin(np.where(log_likelihoods_perturbed == 0, np.nan, log_likelihoods_perturbed))
445-
log_likelihoods_perturbed_max = np.nanmax(np.where(log_likelihoods_perturbed == 0, np.nan, log_likelihoods_perturbed))
450+
log_likelihoods_base = self.result._array_2d_from(
451+
self.result.log_likelihoods_base
452+
)
453+
log_likelihoods_perturbed = self.result._array_2d_from(
454+
self.result.log_likelihoods_perturbed
455+
)
456+
457+
log_likelihoods_base_min = np.nanmin(
458+
np.where(log_likelihoods_base == 0, np.nan, log_likelihoods_base)
459+
)
460+
log_likelihoods_base_max = np.nanmax(
461+
np.where(log_likelihoods_base == 0, np.nan, log_likelihoods_base)
462+
)
463+
log_likelihoods_perturbed_min = np.nanmin(
464+
np.where(log_likelihoods_perturbed == 0, np.nan, log_likelihoods_perturbed)
465+
)
466+
log_likelihoods_perturbed_max = np.nanmax(
467+
np.where(log_likelihoods_perturbed == 0, np.nan, log_likelihoods_perturbed)
468+
)
446469

447-
self.mat_plot_2d.cmap.kwargs["vmin"] = np.min([log_likelihoods_base_min, log_likelihoods_perturbed_min])
448-
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max([log_likelihoods_base_max, log_likelihoods_perturbed_max])
470+
self.mat_plot_2d.cmap.kwargs["vmin"] = np.min(
471+
[log_likelihoods_base_min, log_likelihoods_perturbed_min]
472+
)
473+
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(
474+
[log_likelihoods_base_max, log_likelihoods_perturbed_max]
475+
)
449476

450477
self.mat_plot_2d.plot_array(
451478
array=log_likelihoods_base,

autolens/point/dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,10 @@ def info(self) -> str:
8484
info += f"fluxes_noise_map : {self.fluxes_noise_map}\n"
8585
return info
8686

87-
def extent_from(self, buffer : float = 0.1):
88-
87+
def extent_from(self, buffer: float = 0.1):
8988
y_max = max(self.positions[:, 0]) + buffer
9089
y_min = min(self.positions[:, 0]) - buffer
9190
x_max = max(self.positions[:, 1]) + buffer
9291
x_min = min(self.positions[:, 1]) - buffer
9392

94-
return [y_min, y_max, x_min, x_max]
93+
return [y_min, y_max, x_min, x_max]

autolens/point/model/analysis.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525

2626
class AnalysisPoint(AgAnalysis, AnalysisLens):
27-
2827
Visualizer = VisualizerPoint
2928
Result = ResultPoint
3029

autolens/point/model/plotter_interface.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212

1313
class PlotterInterfacePoint(PlotterInterface):
14-
1514
def dataset_point(self, dataset: PointDataset):
1615
"""
1716
Output visualization of an `PointDataset` dataset, typically before a model-fit is performed.
@@ -54,7 +53,10 @@ def should_plot(name):
5453
dataset_plotter.subplot_dataset()
5554

5655
def fit_point(
57-
self, fit: FitPointDataset, during_analysis: bool, subfolders: str = "fit_dataset"
56+
self,
57+
fit: FitPointDataset,
58+
during_analysis: bool,
59+
subfolders: str = "fit_dataset",
5860
):
5961
"""
6062
Visualizes a `FitPointDataset` object, which fits an imaging dataset.
@@ -104,7 +106,6 @@ def should_plot(name):
104106
fit_plotter.subplot_fit()
105107

106108
if not during_analysis and should_plot("all_at_end_png"):
107-
108109
mat_plot_2d = self.mat_plot_2d_from(
109110
subfolders=path.join("fit_dataset", "end"),
110111
)
@@ -113,7 +114,4 @@ def should_plot(name):
113114
fit=fit, mat_plot_2d=mat_plot_2d, include_2d=self.include_2d
114115
)
115116

116-
fit_plotter.figures_2d(
117-
positions=True,
118-
fluxes=True
119-
)
117+
fit_plotter.figures_2d(positions=True, fluxes=True)

autolens/point/model/visualizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66

77
class VisualizerPoint(af.Visualizer):
8-
98
@staticmethod
109
def visualize_before_fit(
1110
analysis,
@@ -76,7 +75,9 @@ def visualize(
7675

7776
tracer = fit.tracer
7877

79-
grid = ag.Grid2D.from_extent(extent=fit.dataset.extent_from(), shape_native=(100, 100))
78+
grid = ag.Grid2D.from_extent(
79+
extent=fit.dataset.extent_from(), shape_native=(100, 100)
80+
)
8081

8182
plotter_interface.tracer(
8283
tracer=tracer, grid=grid, during_analysis=during_analysis

autolens/point/plot/fit_point_plotters.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,36 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False):
4141
)
4242

4343
if self.mat_plot_2d.axis.kwargs.get("extent") is None:
44-
4544
buffer = 0.1
4645

47-
y_max = max(
48-
max(self.fit.dataset.positions[:, 0]),
49-
max(self.fit.positions.model_data[:, 0]),
50-
) + buffer
51-
y_min = min(
52-
min(self.fit.dataset.positions[:, 0]),
53-
min(self.fit.positions.model_data[:, 0]),
54-
) - buffer
55-
x_max = max(
56-
max(self.fit.dataset.positions[:, 1]),
57-
max(self.fit.positions.model_data[:, 1]),
58-
) + buffer
59-
x_min = min(
60-
min(self.fit.dataset.positions[:, 1]),
61-
min(self.fit.positions.model_data[:, 1]),
62-
) - buffer
46+
y_max = (
47+
max(
48+
max(self.fit.dataset.positions[:, 0]),
49+
max(self.fit.positions.model_data[:, 0]),
50+
)
51+
+ buffer
52+
)
53+
y_min = (
54+
min(
55+
min(self.fit.dataset.positions[:, 0]),
56+
min(self.fit.positions.model_data[:, 0]),
57+
)
58+
- buffer
59+
)
60+
x_max = (
61+
max(
62+
max(self.fit.dataset.positions[:, 1]),
63+
max(self.fit.positions.model_data[:, 1]),
64+
)
65+
+ buffer
66+
)
67+
x_min = (
68+
min(
69+
min(self.fit.dataset.positions[:, 1]),
70+
min(self.fit.positions.model_data[:, 1]),
71+
)
72+
- buffer
73+
)
6374

6475
extent = [y_min, y_max, x_min, x_max]
6576

test_autolens/point/model/test_plotter_interface_point.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,16 @@ def make_plotter_interface_plotter_setup():
1313
return path.join("{}".format(directory), "files")
1414

1515

16-
def test__fit_point(
17-
fit_point_dataset_x2_plane, include_2d_all, plot_path, plot_patch
18-
):
16+
def test__fit_point(fit_point_dataset_x2_plane, include_2d_all, plot_path, plot_patch):
1917
if os.path.exists(plot_path):
2018
shutil.rmtree(plot_path)
2119

2220
plotter_interface = PlotterInterfacePoint(image_path=plot_path)
2321

24-
plotter_interface.fit_point(
25-
fit=fit_point_dataset_x2_plane, during_analysis=False
26-
)
22+
plotter_interface.fit_point(fit=fit_point_dataset_x2_plane, during_analysis=False)
2723

2824
assert path.join(plot_path, "subplot_fit.png") in plot_patch.paths
2925

3026
plot_path = path.join(plot_path, "fit_dataset")
3127

32-
assert path.join(plot_path, "fit_point_positions.png") in plot_patch.paths
28+
assert path.join(plot_path, "fit_point_positions.png") in plot_patch.paths

0 commit comments

Comments
 (0)