Skip to content

Commit

Permalink
add corner plot function to visualization module
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert Morgan committed Mar 3, 2021
1 parent 50ae579 commit 9a88fcb
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions deeplenstronomy/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from astropy.visualization import make_lupton_rgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def _no_stretch(val):
return val
Expand Down Expand Up @@ -64,3 +65,54 @@ def view_image_rgb(images, Q=2.0, stretch=4.0, **imshow_kwargs):
plt.close()

return

def view_corner(metadata, labels, hist_kwargs={}, hist2d_kwargs={}, label_kwargs={}):
"""
Show a corner plot of the columns in a DataFrame.
Args:
metadata (pd.DataFrame): A pandas DataFrame containing the metadata to visualize
labels (dict): A dictionary mapping column names to axis labels
hist_kwargs (dict): keyword arguments to pass to matplotlib.pyplot.hist
hist2d_kwargs (dict): keyword arguments to pass to matplotlib.pyplot.hist2d
label_kwargs (dict): keyword arguments to pass to matplotlib.axes.Axes.set_xlabel (and ylabel)
Raises:
KeyError: if one or more of the columns are not present in the metadata
TypeError: if metadata is not a pandas DataFrame
TypeError: if labels is not a dict
"""
if not isinstance(metadata, pd.DataFrame):
raise TypeError("first argument must be a pandas DataFrame")

if not isinstance(labels, dict):
raise TypeError("second argument must be a list")

if any([x not in metadata.columns for x in labels]):
raise KeyError("One or more passed columns is not present in the metadata")

fig, axs = plt.subplots(len(labels), len(labels), figsize=(14,14))

for row, row_label in enumerate(labels.keys()):
for col, col_label in enumerate(labels.keys()):

if row == col:
# hist
axs[row, col].hist(metadata[row_label].values, **hist_kwargs)

elif row > col:
# hist2d
axs[row, col].hist2d(metadata[col_label].values,
metadata[row_label].values, **hist2d_kwargs)
else:
axs[row, col].set_visible(False)

if row == len(labels) -1:
axs[row, col].set_xlabel(labels[col_label], **label_kwargs)

if col == 0 and row != 0:
axs[row, col].set_ylabel(labels[row_label], **label_kwargs)

fig.tight_layout()
plt.show()
plt.close()

0 comments on commit 9a88fcb

Please sign in to comment.