Skip to content

Commit

Permalink
Update display_image.py for PEP8
Browse files Browse the repository at this point in the history
changes for PEP8 compiance
  • Loading branch information
bjkuhn authored Nov 8, 2023
1 parent b81b1a5 commit e060237
Showing 1 changed file with 37 additions and 69 deletions.
106 changes: 37 additions & 69 deletions notebooks/WFC3/exception_report/docs/display_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#! /usr/bin/env python

import sys

from astropy.io import fits
Expand All @@ -15,10 +14,10 @@ def display_image(filename,
ima_multiread=False,
figsize=(18, 18),
dpi=200):

""" A function to display the 'SCI', 'ERR/WHT', and 'DQ/CTX' arrays
of any WFC3 fits image. This function returns nothing, but will display
the requested image on the screen when called.
"""
A function to display the 'SCI', 'ERR/WHT', and 'DQ/CTX' arrays
of any WFC3 fits image. This function returns nothing, but will display
the requested image on the screen when called.
Authors
-------
Expand Down Expand Up @@ -109,7 +108,7 @@ def display_image(filename,
print("Invalid image section specified")
return 0, 0
try:
xstart = int(xsec[: xs])
xstart = int(xsec[:xs])
except ValueError:
print("Problem getting xstart")
return
Expand All @@ -133,7 +132,6 @@ def display_image(filename,
print("Problem getting yend")
return

bunit = get_bunit(h1)
detector = h['detector']
issubarray = h['subarray']
si = h['primesi']
Expand All @@ -152,15 +150,14 @@ def display_image(filename,
print('-'*44)
print(f"Filter = {h['filter']}, Date-Obs = {h['date-obs']} T{h['time-obs']},\nTarget = {h['targname']}, Exptime = {h['exptime']}, Subarray = {issubarray}, Units = {h1['bunit']}\n")


if detector == 'UVIS':
if ima_multiread == True:
if ima_multiread is True:
sys.exit("keyword argument 'ima_multiread' can only be set to True for 'ima.fits' files")
try:
if all_pixels:
xstart = 0
ystart = 0
xend = naxis1 # full x size
xend = naxis1 # full x size
yend = naxis2*2 # full y size

with fits.open(imagename) as hdu:
Expand All @@ -177,7 +174,7 @@ def display_image(filename,
fullerr = np.concatenate([uvis2_err, uvis1_err])

fullsci = fullsci[ystart:yend, xstart:xend]
fulldq = fulldq[ystart:yend, xstart:xend]
fulldq = fulldq[ystart:yend, xstart:xend]
fullerr = fullerr[ystart:yend, xstart:xend]

make1x3plot(scaling, colormaps, fullsci, fullerr, fulldq,
Expand All @@ -202,10 +199,10 @@ def display_image(filename,
except (IndexError, KeyError):

if all_pixels:
xstart = 0
ystart = 0
xend = naxis1 # full x size
yend = naxis2 # full y size
xstart = 0
ystart = 0
xend = naxis1 # full x size
yend = naxis2 # full y size

with fits.open(imagename) as hdu:
uvis_ext1 = hdu[1].data
Expand Down Expand Up @@ -233,15 +230,14 @@ def display_image(filename,
ax1.set_title(f"WFC3/{detector} {fname} {h1['extname']} ext")
fig.colorbar(im1, ax=ax1, shrink=.75, pad=.03)


if detector == 'IR' and '_ima.fits' not in fname:
if ima_multiread == True:
if ima_multiread is True:
sys.exit("keyword argument 'ima_multiread' can only be set to True for 'ima.fits' files")
if all_pixels:
xstart = 0
ystart = 0
xend = naxis1 # full x size
yend = naxis2 # full y size
xend = naxis1 # full x size
yend = naxis2 # full y size

try:
with fits.open(imagename) as hdu:
Expand All @@ -251,7 +247,7 @@ def display_image(filename,

data_sci = data_sci[ystart:yend, xstart:xend]
data_err = data_err[ystart:yend, xstart:xend]
data_dq = data_dq[ystart:yend, xstart:xend]
data_dq = data_dq[ystart:yend, xstart:xend]

make1x3plot(scaling, colormaps, data_sci, data_err, data_dq,
xstart, xend, ystart, yend,
Expand All @@ -268,78 +264,49 @@ def display_image(filename,
ax1.set_title(f"WFC3/{detector} {fname} {h1['extname']} ext")
fig.colorbar(im1, ax=ax1,shrink=.75, pad=.03)


if '_ima.fits' in fname:
if all_pixels:
xstart = 0
ystart = 0
xend = naxis1 # full x size
yend = naxis2 # full y size
xend = naxis1 # full x size
yend = naxis2 # full y size

if ima_multiread == True:
if ima_multiread is True:
nsamps = h['NSAMP']
for ext in reversed(range(1,nsamps+1)):
with fits.open(imagename) as hdu:
data_sci = hdu['SCI', ext].data
data_err = hdu['ERR', ext].data
data_dq = hdu['DQ', ext].data
data_dq = hdu['DQ', ext].data

data_sci = data_sci[ystart:yend, xstart:xend]
data_err = data_err[ystart:yend, xstart:xend]
data_dq = data_dq[ystart:yend, xstart:xend]
data_dq = data_dq[ystart:yend, xstart:xend]

makeIR1x3plot(scaling, colormaps, data_sci, data_err, data_dq,
xstart, xend, ystart, yend,
detector, fname, h1, h2, h3, nsamps, ext,
figsize, dpi)

if ima_multiread == False:
if ima_multiread is False:
with fits.open(imagename) as hdu:
data_sci = hdu['SCI', 1].data
data_err = hdu['ERR', 1].data
data_dq = hdu['DQ', 1].data
data_dq = hdu['DQ', 1].data

data_sci = data_sci[ystart:yend, xstart:xend]
data_err = data_err[ystart:yend, xstart:xend]
data_dq = data_dq[ystart:yend, xstart:xend]
data_dq = data_dq[ystart:yend, xstart:xend]

make1x3plot(scaling, colormaps, data_sci, data_err, data_dq,
xstart, xend, ystart, yend,
detector, fname, h1, h2, h3,
figsize, dpi)


def get_bunit(ext1header):
""" Get the brightness unit for the plot axis label.
Parameters
----------
ext1header: Header
The extension 1 header of the fits file being displayed. This is the
extension that contains the brightness unit keyword.
Returns
-------
The string of the brightness unit for the axis label
{'counts', 'counts/s','e$^-$', 'e$^-$/s'}
"""
units = ext1header['bunit']

if units == 'COUNTS':
return 'counts'
elif units == 'COUNTS/S':
return 'counts/s'
elif units == 'ELECTRONS':
return 'e$^-$'
elif units == 'ELECTRONS/S':
return 'e$^-$/s'
else:
return units


def get_scale_limits(scaling, array, extname):
""" Get the scale limits to use for the image extension being displayed.
"""
Get the scale limits to use for the image extension being displayed.
Parameters
----------
Expand Down Expand Up @@ -368,28 +335,28 @@ def get_scale_limits(scaling, array, extname):
"""
if extname == 'DQ':
if scaling[0] == None and scaling[1] == None:
if scaling[0] is None and scaling[1] is None:
z1, z2 = array.min(), array.max()
elif scaling[0] == None and scaling[1] != None:
elif scaling[0] is None and scaling[1] is not None:
z1 = array.min()
z2 = scaling[1]
elif scaling[0] != None and scaling[1] == None:
elif scaling[0] is not None and scaling[1] is None:
z1 = scaling[0]
z2 = array.max()
elif scaling[0] != None and scaling[1] != None:
elif scaling[0] is not None and scaling[1] is not None:
z1 = scaling[0]
z2 = scaling[1]

elif extname == 'SCI' or extname == 'ERR':
if scaling[0] == None and scaling[1] == None:
if scaling[0] is None and scaling[1] is None:
z1, z2 = zscale.zscale(array)
elif scaling[0] == None and scaling[1] != None:
elif scaling[0] is None and scaling[1] is not None:
z1 = zscale.zscale(array)[0]
z2 = scaling[1]
elif scaling[0] != None and scaling[1] == None:
elif scaling[0] is not None and scaling[1] is None:
z1 = scaling[0]
z2 = zscale.zscale(array)[1]
elif scaling[0] != None and scaling[1] != None:
elif scaling[0] is not None and scaling[1] is not None:
z1 = scaling[0]
z2 = scaling[1]
else:
Expand Down Expand Up @@ -480,7 +447,7 @@ def make1x3plot(scaling, colormaps, fullsci, fullerr, fulldq,

z1_sci, z2_sci = get_scale_limits(scaling[0], fullsci, 'SCI')
z1_err, z2_err = get_scale_limits(scaling[1], fullerr, 'ERR')
z1_dq, z2_dq = get_scale_limits(scaling[2], fulldq, 'DQ')
z1_dq, z2_dq = get_scale_limits(scaling[2], fulldq, 'DQ')

fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=figsize, dpi=dpi)

Expand All @@ -500,6 +467,7 @@ def make1x3plot(scaling, colormaps, fullsci, fullerr, fulldq,
fig.colorbar(im2, ax=ax2, shrink=.25, pad=.03)
fig.colorbar(im3, ax=ax3, shrink=.25, pad=.03)


def makeIR1x3plot(scaling, colormaps, data_sci, data_err, data_dq,
xstart, xend, ystart, yend,
detector, fname, h1, h2, h3, nsamps, ext,
Expand Down Expand Up @@ -587,7 +555,7 @@ def makeIR1x3plot(scaling, colormaps, data_sci, data_err, data_dq,

z1_sci, z2_sci = get_scale_limits(scaling[0], data_sci, 'SCI')
z1_err, z2_err = get_scale_limits(scaling[1], data_err, 'ERR')
z1_dq, z2_dq = get_scale_limits(scaling[2], data_dq, 'DQ')
z1_dq, z2_dq = get_scale_limits(scaling[2], data_dq, 'DQ')

fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=figsize, dpi=dpi)
im1 = ax1.imshow(data_sci, origin='lower', extent=(xstart, xend, ystart, yend), cmap=colormaps[0], vmin=z1_sci, vmax=z2_sci)
Expand Down

0 comments on commit e060237

Please sign in to comment.