@@ -1,4 +1,16 @@
+from io import BytesIO
+import logging
+import os
+import tempfile
+
+import numpy as np
+from astropy.io import fits
+
 from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
+from datalab.datalab_session.util import store_fits_output, get_archive_from_basename
+
+log = logging.getLogger()
+log.setLevel(logging.INFO)
 
 
 class Median(BaseDataOperation):
@@ -31,4 +43,62 @@ def wizard_description():
         }
 
     def operate(self):
-        pass
+        input_files = self.input_data.get('input_files', [])
+        file_count = len(input_files)
+
+        if file_count == 0:
+            return { 'output_files': [] }
+
+        log.info(f'Executing median operation on {file_count} files')
+
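+        # Each frame is spilled to a disk-backed memmap so any number of images
+        # can be stacked without holding them all in memory at once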
+        with tempfile.TemporaryDirectory() as temp_dir:
+            memmaps = []
+
+            for index, file_info in enumerate(input_files):
+                basename = file_info.get('basename', 'No basename found')
+                archive_record = get_archive_from_basename(basename)
+
+                try:
+                    fits_url = archive_record[0].get('url', 'No URL found')
+                except IndexError:
+                    log.warning(f'No archive record found for {basename}, skipping')
+                    continue
+
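+                # use_fsspec lets astropy open the FITS file directly from the
+                # remote archive URL instead of requiring a local download first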
+                with fits.open(fits_url, use_fsspec=True) as hdu_list:
+                    data = hdu_list['SCI'].data
+                    memmap_path = os.path.join(temp_dir, f'memmap_{index}.dat')
+                    memmap_array = np.memmap(memmap_path, dtype=data.dtype, mode='w+', shape=data.shape)
+                    memmap_array[:] = data[:]
+                    # Record dtype and shape with the path: each memmap must be
+                    # re-opened with the same parameters it was written with
+                    memmaps.append((memmap_path, data.dtype, data.shape))
+
+                self.set_percent_completion((index + 1) / file_count)
+
+            # Files with no archive record are skipped above; bail out if nothing
+            # was fetched, since np.stack would fail on an empty list
+            if not memmaps:
+                return { 'output_files': [] }
+
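+            # Re-open the memmaps read-only; the median is computed from these
+            # disk-backed views rather than full in-memory copies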
+            image_data_list = [
+                np.memmap(path, dtype=dtype, mode='r', shape=shape)
+                for path, dtype, shape in memmaps
+            ]
+
+            # Crop all images to the smallest extent along each axis so they stack cleanly
+            min_shape = np.min([arr.shape for arr in image_data_list], axis=0)
+            cropped_data_list = [arr[:min_shape[0], :min_shape[1]] for arr in image_data_list]
+            stacked_data = np.stack(cropped_data_list, axis=2)
+
+            # Calculate the median along the z axis
+            median = np.median(stacked_data, axis=2)
+
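+            # Build the output FITS file, recording the operation's cache key in the
+            # primary header so the product can be traced back to this request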
+            cache_key = self.generate_cache_key()
+            header = fits.Header([('KEY', cache_key)])
+            primary_hdu = fits.PrimaryHDU(header=header)
+            image_hdu = fits.ImageHDU(median)
+            hdu_list = fits.HDUList([primary_hdu, image_hdu])
+
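+            # Serialize the HDUList into an in-memory buffer so it can be uploaded
+            # without touching the local filesystem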
+            fits_buffer = BytesIO()
+            hdu_list.writeto(fits_buffer)
+            fits_buffer.seek(0)
+
+            # Write the HDU List to the output FITS file in the bucket
+            store_fits_output(cache_key, fits_buffer)
+
+            # TODO: No output yet, need to build a thumbnail service
+            output = {'output_files': []}
+            self.set_percent_completion(1.0)
+            self.set_output(output)