
Commit

rename percents to operation progress, refactor get_hdu to work on fits files
LTDakin committed Sep 30, 2024
1 parent 1e66ff3 commit 4f91efc
Showing 15 changed files with 51 additions and 49 deletions.
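
Taken together, the diffs below split file access into two steps: s3_utils.get_fits resolves a basename to a local FITS path, and file_utils.get_hdu opens that path and returns the requested extension. A minimal sketch of the new calling pattern (the basename is a made-up example):

    from datalab.datalab_session.s3_utils import get_fits
    from datalab.datalab_session.file_utils import get_hdu

    fits_path = get_fits('example-frame-e91')   # hypothetical basename; downloads to a tmp dir if needed
    sci_hdu = get_hdu(fits_path, 'SCI')          # opens the local file and returns the SCI extension
    print(sci_hdu.data.shape)                    # per the new warning in file_utils, the underlying file is still open here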
5 changes: 4 additions & 1 deletion datalab/datalab_session/analysis/line_profile.py
@@ -4,6 +4,7 @@
from astropy import coordinates

from datalab.datalab_session.file_utils import scale_points, get_hdu
from datalab.datalab_session.s3_utils import get_fits

# For creating an array of brightness along a user drawn line
def line_profile(input: dict):
@@ -19,7 +20,9 @@ def line_profile(input: dict):
y2 (int): The y coordinate of the ending point
}
"""
sci_hdu = get_hdu(input['basename'], 'SCI')
fits_path = get_fits(input['basename'])

sci_hdu = get_hdu(fits_path, 'SCI')

x_points, y_points = scale_points(input["height"], input["width"], sci_hdu.data.shape[0], sci_hdu.data.shape[1], x_points=[input["x1"], input["x2"]], y_points=[input["y1"], input["y2"]])

7 changes: 5 additions & 2 deletions datalab/datalab_session/analysis/source_catalog.py
@@ -1,13 +1,16 @@
import numpy as np

from datalab.datalab_session.file_utils import get_hdu, scale_points
from datalab.datalab_session.s3_utils import get_fits

def source_catalog(input: dict):
"""
Returns a dict representing the source catalog data with x,y coordinates and flux values
"""
cat_hdu = get_hdu(input['basename'], 'CAT')
sci_hdu = get_hdu(input['basename'], 'SCI')
fits_path = get_fits(input['basename'])

cat_hdu = get_hdu(fits_path, 'CAT')
sci_hdu = get_hdu(fits_path, 'SCI')

# The number of sources to send back to the frontend, default 50
SOURCE_CATALOG_COUNT = min(50, len(cat_hdu.data["x"]))
22 changes: 10 additions & 12 deletions datalab/datalab_session/data_operations/data_operation.py
@@ -5,6 +5,7 @@
from django.core.cache import cache
import numpy as np

from datalab.datalab_session.s3_utils import get_fits
from datalab.datalab_session.tasks import execute_data_operation
from datalab.datalab_session.file_utils import get_hdu

@@ -57,7 +58,7 @@ def perform_operation(self):
status = self.get_status()
if status == 'PENDING' or status == 'FAILED':
self.set_status('IN_PROGRESS')
self.set_percent_completion(0.0)
self.set_operation_progress(0.0)
# This asynchronous task will call the operate() method on the proper operation
execute_data_operation.send(self.name(), self.input_data)

@@ -78,15 +79,15 @@ def set_message(self, message: str):
def get_message(self) -> str:
return cache.get(f'operation_{self.cache_key}_message', '')

def set_percent_completion(self, percent_completed: float):
cache.set(f'operation_{self.cache_key}_percent_completion', percent_completed, CACHE_DURATION)
def set_operation_progress(self, percent_completed: float):
cache.set(f'operation_{self.cache_key}_progress', percent_completed, CACHE_DURATION)

def get_percent_completion(self) -> float:
return cache.get(f'operation_{self.cache_key}_percent_completion', 0.0)
def get_operation_progress(self) -> float:
return cache.get(f'operation_{self.cache_key}_progress', 0.0)

def set_output(self, output_data: dict):
self.set_status('COMPLETED')
self.set_percent_completion(1.0)
self.set_operation_progress(1.0)
cache.set(f'operation_{self.cache_key}_output', output_data, CACHE_DURATION)

def get_output(self) -> dict:
@@ -96,19 +97,16 @@ def set_failed(self, message: str):
self.set_status('FAILED')
self.set_message(message)

def get_fits_npdata(self, input_files: list[dict], percent=None, cur_percent=None) -> list[np.memmap]:
total_files = len(input_files)
def get_fits_npdata(self, input_files: list[dict]) -> list[np.memmap]:
image_data_list = []

# get the fits urls and extract the image data
for index, file_info in enumerate(input_files, start=1):
basename = file_info.get('basename', 'No basename found')
source = file_info.get('source', 'No source found')

sci_hdu = get_hdu(basename, 'SCI', source)
fits_path = get_fits(file_info['basename'], file_info['source'])
sci_hdu = get_hdu(fits_path, 'SCI')
image_data_list.append(sci_hdu.data)

if percent is not None and cur_percent is not None:
self.set_percent_completion(cur_percent + index/total_files * percent)

return image_data_list
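
With the percent/cur_percent arguments removed, get_fits_npdata no longer reports progress itself; each operation now calls set_operation_progress at its own checkpoints, as the median, stacking, and subtraction diffs below show. A minimal sketch of the pattern inside an operate() method (the 0.4 checkpoint mirrors the median operation):

    image_data_list = self.get_fits_npdata(input_files)   # only loads SCI data, no progress side effects
    self.set_operation_progress(0.4)                        # the caller reports progress explicitly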
2 changes: 1 addition & 1 deletion datalab/datalab_session/data_operations/long.py
@@ -42,7 +42,7 @@ def operate(self):
for i, file in enumerate(self.input_data.get('input_files', [])):
print(f"Processing long operation on file {file.get('basename', 'No basename found')}")
sleep(per_image_timeout)
self.set_percent_completion((i+1) / num_files)
self.set_operation_progress((i+1) / num_files)
# Done "processing" the files so set the output which sets the final status
output = {
'output_files': self.input_data.get('input_files', [])
3 changes: 2 additions & 1 deletion datalab/datalab_session/data_operations/median.py
@@ -49,7 +49,8 @@ def operate(self):

log.info(f'Executing median operation on {len(input)} files')

image_data_list = self.get_fits_npdata(input, percent=0.4, cur_percent=0.0)
image_data_list = self.get_fits_npdata(input)
self.set_operation_progress(0.40)

cropped_data_list = crop_arrays(image_data_list)
stacked_data = np.stack(cropped_data_list, axis=2)
4 changes: 2 additions & 2 deletions datalab/datalab_session/data_operations/normalization.py
@@ -46,7 +46,7 @@ def operate(self):
log.info(f'Executing normalization operation on {len(input)} file(s)')

image_data_list = self.get_fits_npdata(input)
self.set_percent_completion(0.40)
self.set_operation_progress(0.40)

output_files = []
for index, image in enumerate(image_data_list):
@@ -58,7 +58,7 @@ def operate(self):
output_file = save_fits_and_thumbnails(self.cache_key, fits_file, large_jpg_path, small_jpg_path, index=index)
output_files.append(output_file)

self.set_percent_completion(self.get_percent_completion() + .40 * (index + 1) / len(input))
self.set_operation_progress(self.get_operation_progress() + .40 * (index + 1) / len(input))

output = {'output_files': output_files}

4 changes: 2 additions & 2 deletions datalab/datalab_session/data_operations/rgb_stack.py
@@ -65,7 +65,7 @@ def operate(self):
fits_paths = []
for file in rgb_input_list:
fits_paths.append(get_fits(file.get('basename')))
self.set_percent_completion(self.get_percent_completion() + 0.2)
self.set_operation_progress(self.get_operation_progress() + 0.2)

large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_paths, color=True)

@@ -83,6 +83,6 @@ def operate(self):
output = {'output_files': []}
raise ClientAlertException('RGB Stack operation requires exactly 3 input files')

self.set_percent_completion(1.0)
self.set_operation_progress(1.0)
self.set_output(output)
log.info(f'RGB Stack output: {self.get_output()}')
6 changes: 3 additions & 3 deletions datalab/datalab_session/data_operations/stacking.py
@@ -52,17 +52,17 @@ def operate(self):

image_data_list = self.get_fits_npdata(input_files)

self.set_percent_completion(0.4)
self.set_operation_progress(0.4)

cropped_data = crop_arrays(image_data_list)
stacked_data = np.stack(cropped_data, axis=2)

self.set_percent_completion(0.6)
self.set_operation_progress(0.6)

# using the numpy library's sum method
stacked_sum = np.sum(stacked_data, axis=2)

self.set_percent_completion(0.8)
self.set_operation_progress(0.8)

fits_file = create_fits(self.cache_key, stacked_sum)

6 changes: 3 additions & 3 deletions datalab/datalab_session/data_operations/subtraction.py
@@ -64,10 +64,10 @@ def operate(self):
log.info(f'Executing subtraction operation on {len(input_files)} files')

input_image_data_list = self.get_fits_npdata(input_files)
self.set_percent_completion(.30)
self.set_operation_progress(.30)

subtraction_image = self.get_fits_npdata(subtraction_file_input)[0]
self.set_percent_completion(.40)
self.set_operation_progress(.40)

outputs = []
for index, input_image in enumerate(input_image_data_list):
@@ -82,7 +82,7 @@ def operate(self):
output_file = save_fits_and_thumbnails(self.cache_key, fits_file, large_jpg_path, small_jpg_path, index)
outputs.append(output_file)

self.set_percent_completion(self.get_percent_completion() + .50 * (index + 1) / len(input_files))
self.set_operation_progress(self.get_operation_progress() + .50 * (index + 1) / len(input_files))

output = {'output_files': outputs}

14 changes: 5 additions & 9 deletions datalab/datalab_session/file_utils.py
@@ -6,25 +6,21 @@
from fits2image.conversions import fits_to_jpg, fits_to_tif

from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.s3_utils import get_fits, add_file_to_bucket
from datalab.datalab_session.s3_utils import get_fits

log = logging.getLogger()
log.setLevel(logging.INFO)

def get_hdu(basename: str, extension: str = 'SCI', source: str = 'archive') -> list[fits.HDUList]:
def get_hdu(path: str, extension: str = 'SCI') -> list[fits.HDUList]:
"""
Returns a HDU for the given basename from the source
Will download the file to a tmp directory so future calls can open it directly
Returns a HDU for the fits in the given path
Warning: this function returns an opened file that must be closed after use
"""

basename_file_path = get_fits(basename, source)

hdu = fits.open(basename_file_path)
hdu = fits.open(path)
try:
extension = hdu[extension]
except KeyError:
raise ClientAlertException(f"{extension} Header not found in fits file {basename}")
raise ClientAlertException(f"{extension} Header not found in fits file at {path.split('/')[-1]}")

return extension
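
After this refactor get_hdu assumes it is handed a local path and only opens and indexes the file; a missing extension surfaces as a ClientAlertException whose message uses the filename taken from the path. A hedged sketch of that failure path (the path is illustrative):

    from datalab.datalab_session.exceptions import ClientAlertException
    from datalab.datalab_session.file_utils import get_hdu

    try:
        cat_hdu = get_hdu('/tmp/fits/example-frame.fits.fz', 'CAT')   # hypothetical local path
    except ClientAlertException as error:
        print(error)   # e.g. "CAT Header not found in fits file at example-frame.fits.fz"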

4 changes: 2 additions & 2 deletions datalab/datalab_session/models.py
@@ -74,8 +74,8 @@ def status(self):
return cache.get(f'operation_{self.cache_key}_status', 'PENDING')

@property
def percent_completion(self):
return cache.get(f'operation_{self.cache_key}_percent_completion', 0.0)
def operation_progress(self):
return cache.get(f'operation_{self.cache_key}_progress', 0.0)

@property
def output(self):
1 change: 1 addition & 0 deletions datalab/datalab_session/s3_utils.py
@@ -114,6 +114,7 @@ def get_archive_url(basename: str, archive: str = settings.ARCHIVE_API) -> dict:
def get_fits(basename: str, source: str = 'archive'):
"""
Returns a Fits File for the given basename from the source bucket
Will download the file to a tmp directory so future calls can open it directly
"""
basename = basename.replace('-large', '').replace('-small', '')
basename_file_path = os.path.join(settings.TEMP_FITS_DIR, basename)
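
Because get_fits strips the '-large'/'-small' thumbnail suffixes and caches the download under settings.TEMP_FITS_DIR, repeated calls for the same frame should resolve to the same local file. A small illustration, assuming get_fits returns the path it builds here, as its callers in this commit treat it (basenames are hypothetical):

    path_a = get_fits('example-frame-e91-small')   # '-small' is stripped before the lookup
    path_b = get_fits('example-frame-e91')         # same basename after stripping, so the cached tmp copy is reused
    assert path_a == path_b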
4 changes: 2 additions & 2 deletions datalab/datalab_session/serializers.py
@@ -10,14 +10,14 @@ class DataOperationSerializer(serializers.ModelSerializer):
cache_key = serializers.CharField(write_only=True, required=False)
status = serializers.ReadOnlyField()
message = serializers.ReadOnlyField()
percent_completion = serializers.ReadOnlyField()
operation_progress = serializers.ReadOnlyField()
output = serializers.ReadOnlyField()

class Meta:
model = DataOperation
exclude = ('session',)
read_only_fields = (
'id', 'created', 'status', 'percent_completion', 'message', 'output',
'id', 'created', 'status', 'operation_progress', 'message', 'output',
)

class DataSessionSerializer(serializers.ModelSerializer):
4 changes: 2 additions & 2 deletions datalab/datalab_session/tests/test_analysis.py
@@ -17,7 +17,7 @@ def setUp(self):
with open(f'{self.analysis_test_path}test_source_catalog.json') as f:
self.test_source_catalog_data = json.load(f)['test_source_catalog']

@mock.patch('datalab.datalab_session.file_utils.get_fits')
@mock.patch('datalab.datalab_session.analysis.line_profile.get_fits')
def test_line_profile(self, mock_get_fits):

mock_get_fits.return_value = self.analysis_fits_1_path
@@ -34,7 +34,7 @@ def test_line_profile(self, mock_get_fits):

assert_almost_equal(output.get('line_profile').tolist(), self.test_line_profile_data, decimal=3)

@mock.patch('datalab.datalab_session.file_utils.get_fits')
@mock.patch('datalab.datalab_session.analysis.source_catalog.get_fits')
def test_source_catalog(self, mock_get_fits):

mock_get_fits.return_value = self.analysis_fits_1_path
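
The patch targets move from the defining module (file_utils) to the modules that import get_fits, following the usual unittest.mock guidance to patch a name where it is looked up. A minimal sketch of the pattern with a hypothetical test body:

    from unittest import mock

    @mock.patch('datalab.datalab_session.analysis.line_profile.get_fits')
    def test_line_profile_reads_local_fits(mock_get_fits):        # hypothetical test name
        mock_get_fits.return_value = '/tmp/fits/example.fits.fz'  # hand back a local path instead of hitting S3
        ...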
14 changes: 7 additions & 7 deletions datalab/datalab_session/tests/test_operations.py
@@ -138,7 +138,7 @@ def test_wizard_description(self):

def test_operate(self):
self.data_operation.operate()
self.assertEqual(self.data_operation.get_percent_completion(), 1.0)
self.assertEqual(self.data_operation.get_operation_progress(), 1.0)
self.assertEqual(self.data_operation.get_status(), 'COMPLETED')
self.assertEqual(self.data_operation.get_output(), {'output_files': []})

@@ -148,7 +148,7 @@ def test_generate_cache_key(self):

def test_set_get_output(self):
self.data_operation.set_output({'output_files': []})
self.assertEqual(self.data_operation.get_percent_completion(), 1.0)
self.assertEqual(self.data_operation.get_operation_progress(), 1.0)
self.assertEqual(self.data_operation.get_status(), 'COMPLETED')
self.assertEqual(self.data_operation.get_output(), {'output_files': []})

@@ -169,7 +169,7 @@ def tearDown(self):
return super().tearDown()

@mock.patch('datalab.datalab_session.file_utils.tempfile.NamedTemporaryFile')
@mock.patch('datalab.datalab_session.file_utils.get_fits')
@mock.patch('datalab.datalab_session.data_operations.data_operation.get_fits')
@mock.patch('datalab.datalab_session.data_operations.median.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.data_operations.median.create_jpgs')
def test_operate(self, mock_create_jpgs, mock_save_fits_and_thumbnails, mock_get_fits, mock_named_tempfile):
@@ -194,7 +194,7 @@ def test_operate(self, mock_create_jpgs, mock_save_fits_and_thumbnails, mock_get
median.operate()
output = median.get_output().get('output_files')

self.assertEqual(median.get_percent_completion(), 1.0)
self.assertEqual(median.get_operation_progress(), 1.0)
self.assertTrue(os.path.exists(output[0]))
self.assertFilesEqual(self.test_median_path, output[0])

@@ -246,7 +246,7 @@ def test_operate(self, mock_get_fits, mock_named_tempfile, mock_create_jpgs, moc
rgb.operate()
output = rgb.get_output().get('output_files')

self.assertEqual(rgb.get_percent_completion(), 1.0)
self.assertEqual(rgb.get_operation_progress(), 1.0)
self.assertTrue(os.path.exists(output[0]))
self.assertFilesEqual(self.test_rgb_path, output[0])

@@ -265,7 +265,7 @@ def tearDown(self):
return super().tearDown()

@mock.patch('datalab.datalab_session.file_utils.tempfile.NamedTemporaryFile')
@mock.patch('datalab.datalab_session.file_utils.get_fits')
@mock.patch('datalab.datalab_session.data_operations.data_operation.get_fits')
@mock.patch('datalab.datalab_session.data_operations.stacking.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.data_operations.stacking.create_jpgs')
def test_operate(self, mock_create_jpgs, mock_save_fits_and_thumbnails, mock_get_fits, mock_named_tempfile):
@@ -323,7 +323,7 @@ def test_operate(self, mock_create_jpgs, mock_save_fits_and_thumbnails, mock_get
output = stack.get_output().get('output_files')

# 100% completion
self.assertEqual(stack.get_percent_completion(), 1.0)
self.assertEqual(stack.get_operation_progress(), 1.0)

# test that file paths are the same
self.assertEqual(self.temp_stacked_path, output[0])
