diff --git a/.gitignore b/.gitignore index 722f2f0b8..37100481c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,170 +1,171 @@ -# Elastic Beanstalk Files -.elasticbeanstalk/* -!.elasticbeanstalk/*.cfg.yml -!.elasticbeanstalk/*.global.yml - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# OS generated files -.DS_Store -.DS_Store? -._* -.Spotlight-V100 -.Trashes -ehthumbs.db -Thumbs.db - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Flask monitoring dashboard stuff: -flask_monitoringdashboard.db -fmd_config.cfg - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ - -node_modules - -# Elastic Beanstalk Files -backend/.elasticbeanstalk/* -!backend/.elasticbeanstalk/*.cfg.yml -!backend/.elasticbeanstalk/*.global.yml - -# VS Code -settings.json -.vscode - -# Environments with sensitive info -.env -fmd_config.cfg - -# from create-react-app - -# dependencies -/node_modules -/.pnp -.pnp.js - -# testing -/coverage -**/coverage - -# production -/build - -# models and other pickled files -*.h5 -*.npy - -# misc -.DS_Store -.env.local -.env.development.local -.env.test.local -.env.production.local - -npm-debug.log* -yarn-debug.log* -yarn-error.log* -package-lock.json -package.json - -.nyc_output +# Elastic Beanstalk Files +.elasticbeanstalk/* +!.elasticbeanstalk/*.cfg.yml +!.elasticbeanstalk/*.global.yml + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Flask monitoring dashboard stuff: +flask_monitoringdashboard.db +fmd_config.cfg + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +deepcell-env/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +node_modules + +# Elastic Beanstalk Files +backend/.elasticbeanstalk/* +!backend/.elasticbeanstalk/*.cfg.yml +!backend/.elasticbeanstalk/*.global.yml + +# VS Code +settings.json +.vscode + +# Environments with sensitive info +.env +fmd_config.cfg + +# from create-react-app + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage +**/coverage + +# production +/build + +# models and other pickled files +*.h5 +*.npy + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +package-lock.json +package.json + +.nyc_output diff --git a/backend/deepcell_label/blueprints.py b/backend/deepcell_label/blueprints.py index f2da822e7..6ddce6b5e 100644 --- a/backend/deepcell_label/blueprints.py +++ b/backend/deepcell_label/blueprints.py @@ -1,208 +1,208 @@ -"""Flask blueprint for modular routes.""" -from __future__ import absolute_import, division, print_function - -import io -import os -import tempfile -import timeit -import traceback - -import boto3 -import requests -from flask import Blueprint, abort, current_app, jsonify, request, send_file -from werkzeug.exceptions import HTTPException - -from deepcell_label.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, DELETE_TEMP -from deepcell_label.export import Export -from deepcell_label.label import Edit -from deepcell_label.loaders import Loader -from deepcell_label.models import Project - -bp = Blueprint('label', __name__) # pylint: disable=C0103 - - -@bp.route('/health') -def health(): - """Returns success if the application is ready.""" - return jsonify({'message': 'success'}), 200 - - -@bp.errorhandler(Exception) -def handle_exception(error): - """Handle all uncaught exceptions""" - # pass through HTTP errors - if isinstance(error, HTTPException): - return error - - current_app.logger.error( - 'Encountered %s: %s', error.__class__.__name__, error, exc_info=1 - ) - - traceback.print_exc() - # now you're handling non-HTTP exceptions only - return jsonify({'error': str(error)}), 500 - - -@bp.route('/api/project/', methods=['GET']) -def get_project(project): - start = timeit.default_timer() - project = Project.get(project) - if not project: - return abort(404, description=f'project {project} not found') - bucket = request.args.get('bucket', default=project.bucket) - s3 = boto3.client('s3') - data = io.BytesIO() - s3.download_fileobj(bucket, project.key, data) - data.seek(0) - current_app.logger.info( - f'Loaded project {project.key} from {bucket} in {timeit.default_timer() - start} s.', - ) - return 
send_file(data, mimetype='application/zip') - - -@bp.route('/api/project', methods=['POST']) -def create_project(): - """ - Create a new Project from URL. - """ - start = timeit.default_timer() - if 'images' in request.form: - images_url = request.form['images'] - else: - return abort( - 400, - description='Include "images" in the request form with a URL to download the project data.', - ) - labels_url = request.form['labels'] if 'labels' in request.form else None - axes = request.form['axes'] if 'axes' in request.form else None - with tempfile.NamedTemporaryFile( - delete=DELETE_TEMP - ) as image_file, tempfile.NamedTemporaryFile(delete=DELETE_TEMP) as label_file: - if images_url is not None: - image_response = requests.get(images_url) - if image_response.status_code != 200: - return ( - image_response.text, - image_response.status_code, - image_response.headers.items(), - ) - image_file.write(image_response.content) - image_file.seek(0) - if labels_url is not None: - labels_response = requests.get(labels_url) - if labels_response.status_code != 200: - return ( - labels_response.text, - labels_response.status_code, - labels_response.headers.items(), - ) - label_file.write(labels_response.content) - label_file.seek(0) - else: - label_file = image_file - loader = Loader(image_file, label_file, axes) - project = Project.create(loader) - if not DELETE_TEMP: - image_file.close() - label_file.close() - os.remove(image_file.name) # Manually close and delete if using Windows - current_app.logger.info( - 'Created project %s from %s in %s s.', - project.project, - f'{images_url}' if labels_url is None else f'{images_url} and {labels_url}', - timeit.default_timer() - start, - ) - return jsonify(project.project) - - -@bp.route('/api/project/dropped', methods=['POST']) -def create_project_from_dropped_file(): - """ - Create a new Project from drag & dropped file. - """ - start = timeit.default_timer() - input_file = request.files.get('images') - axes = request.form['axes'] if 'axes' in request.form else None - # axes = request.form['axes'] if 'axes' in request.form else DCL_AXES - with tempfile.NamedTemporaryFile(delete=DELETE_TEMP) as f: - f.write(input_file.read()) - f.seek(0) - loader = Loader(f, axes=axes) - project = Project.create(loader) - if not DELETE_TEMP: - f.close() - os.remove(f.name) # Manually close and delete if using Windows - current_app.logger.info( - 'Created project %s from %s in %s s.', - project.project, - input_file.filename, - timeit.default_timer() - start, - ) - return jsonify(project.project) - - -@bp.route('/api/edit', methods=['POST']) -def edit(): - """Loads labeled data from a zip, edits them, and responds with a zip with the edited labels.""" - start = timeit.default_timer() - if 'labels' not in request.files: - return abort(400, description='Attach the labeled data to edit in labels.zip.') - labels_zip = request.files['labels'] - edit = Edit(labels_zip) - current_app.logger.debug( - 'Finished action %s in %s s.', - edit.action, - timeit.default_timer() - start, - ) - return send_file(edit.response_zip, mimetype='application/zip') - - -@bp.route('/api/download', methods=['POST']) -def download_project(): - """ - Create a DeepCell Label zip file for the user to download - The submitted zip should contain the raw and labeled array buffers - in .dat files with the dimensions in dimensions.json, - which are transformed into OME TIFFs in the submitted zips. 
- """ - if 'labels' not in request.files: - return abort(400, description='Attach labels.zip to download.') - labels_zip = request.files['labels'] - id = request.form['id'] - export = Export(labels_zip) - data = export.export_zip - return send_file(data, as_attachment=True, attachment_filename=f'{id}.zip') - - -@bp.route('/api/upload', methods=['POST']) -def submit_project(): - """ - Create and upload an edited DeepCell Label zip file to an S3 bucket. - The submitted zip should contain the raw and labeled array buffers - in .dat files with the dimensions in dimensions.json, - which are transformed into OME TIFFs in the submitted zips. - """ - start = timeit.default_timer() - if 'labels' not in request.files: - return abort(400, description='Attach labels.zip to submit.') - labels_zip = request.files['labels'] - id = request.form['id'] - bucket = request.form['bucket'] - export = Export(labels_zip) - data = export.export_zip - - # store npz file object in bucket/path - s3 = boto3.client( - 's3', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - ) - s3.upload_fileobj(data, bucket, f'{id}.zip') - - current_app.logger.debug( - 'Uploaded %s to S3 bucket %s in %s s.', - f'{id}.zip', - bucket, - timeit.default_timer() - start, - ) - return {} +"""Flask blueprint for modular routes.""" +from __future__ import absolute_import, division, print_function + +import io +import os +import tempfile +import timeit +import traceback + +import boto3 +import requests +from flask import Blueprint, abort, current_app, jsonify, request, send_file +from werkzeug.exceptions import HTTPException + +from deepcell_label.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, DELETE_TEMP +from deepcell_label.export import Export +from deepcell_label.label import Edit +from deepcell_label.loaders import Loader +from deepcell_label.models import Project + +bp = Blueprint('label', __name__) # pylint: disable=C0103 + + +@bp.route('/health') +def health(): + """Returns success if the application is ready.""" + return jsonify({'message': 'success'}), 200 + + +@bp.errorhandler(Exception) +def handle_exception(error): + """Handle all uncaught exceptions""" + # pass through HTTP errors + if isinstance(error, HTTPException): + return error + + current_app.logger.error( + 'Encountered %s: %s', error.__class__.__name__, error, exc_info=1 + ) + + traceback.print_exc() + # now you're handling non-HTTP exceptions only + return jsonify({'error': str(error)}), 500 + + +@bp.route('/api/project/', methods=['GET']) +def get_project(project): + start = timeit.default_timer() + project = Project.get(project) + if not project: + return abort(404, description=f'project {project} not found') + bucket = request.args.get('bucket', default=project.bucket) + s3 = boto3.client('s3') + data = io.BytesIO() + s3.download_fileobj(bucket, project.key, data) + data.seek(0) + current_app.logger.info( + f'Loaded project {project.key} from {bucket} in {timeit.default_timer() - start} s.', + ) + return send_file(data, mimetype='application/zip') + + +@bp.route('/api/project', methods=['POST']) +def create_project(): + """ + Create a new Project from URL. 
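The get_project route above streams the project zip from S3 entirely into memory before responding. A minimal standalone sketch of that pattern, assuming boto3 picks up credentials from the environment as it does by default; the bucket and key in the usage comment are placeholders:

import io

import boto3


def fetch_project_zip(bucket, key):
    """Download an S3 object into an in-memory buffer and rewind it."""
    s3 = boto3.client('s3')
    data = io.BytesIO()
    s3.download_fileobj(bucket, key, data)  # stream the object into the buffer
    data.seek(0)  # rewind so the caller reads from the start
    return data


# Hypothetical usage: buffer = fetch_project_zip('my-label-bucket', 'abc123.zip')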
+ """ + start = timeit.default_timer() + if 'images' in request.form: + images_url = request.form['images'] + else: + return abort( + 400, + description='Include "images" in the request form with a URL to download the project data.', + ) + labels_url = request.form['labels'] if 'labels' in request.form else None + axes = request.form['axes'] if 'axes' in request.form else None + with tempfile.NamedTemporaryFile( + delete=DELETE_TEMP + ) as image_file, tempfile.NamedTemporaryFile(delete=DELETE_TEMP) as label_file: + if images_url is not None: + image_response = requests.get(images_url) + if image_response.status_code != 200: + return ( + image_response.text, + image_response.status_code, + image_response.headers.items(), + ) + image_file.write(image_response.content) + image_file.seek(0) + if labels_url is not None: + labels_response = requests.get(labels_url) + if labels_response.status_code != 200: + return ( + labels_response.text, + labels_response.status_code, + labels_response.headers.items(), + ) + label_file.write(labels_response.content) + label_file.seek(0) + else: + label_file = image_file + loader = Loader(image_file, label_file, axes) + project = Project.create(loader) + if not DELETE_TEMP: + image_file.close() + label_file.close() + os.remove(image_file.name) # Manually close and delete if using Windows + current_app.logger.info( + 'Created project %s from %s in %s s.', + project.project, + f'{images_url}' if labels_url is None else f'{images_url} and {labels_url}', + timeit.default_timer() - start, + ) + return jsonify(project.project) + + +@bp.route('/api/project/dropped', methods=['POST']) +def create_project_from_dropped_file(): + """ + Create a new Project from drag & dropped file. + """ + start = timeit.default_timer() + input_file = request.files.get('images') + axes = request.form['axes'] if 'axes' in request.form else None + # axes = request.form['axes'] if 'axes' in request.form else DCL_AXES + with tempfile.NamedTemporaryFile(delete=DELETE_TEMP) as f: + f.write(input_file.read()) + f.seek(0) + loader = Loader(f, axes=axes) + project = Project.create(loader) + if not DELETE_TEMP: + f.close() + os.remove(f.name) # Manually close and delete if using Windows + current_app.logger.info( + 'Created project %s from %s in %s s.', + project.project, + input_file.filename, + timeit.default_timer() - start, + ) + return jsonify(project.project) + + +@bp.route('/api/edit', methods=['POST']) +def edit(): + """Loads labeled data from a zip, edits them, and responds with a zip with the edited labels.""" + start = timeit.default_timer() + if 'labels' not in request.files: + return abort(400, description='Attach the labeled data to edit in labels.zip.') + labels_zip = request.files['labels'] + edit = Edit(labels_zip) + current_app.logger.debug( + 'Finished action %s in %s s.', + edit.action, + timeit.default_timer() - start, + ) + return send_file(edit.response_zip, mimetype='application/zip') + + +@bp.route('/api/download', methods=['POST']) +def download_project(): + """ + Create a DeepCell Label zip file for the user to download + The submitted zip should contain the raw and labeled array buffers + in .dat files with the dimensions in dimensions.json, + which are transformed into OME TIFFs in the submitted zips. 
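A sketch of how a client might call the /api/project route above with requests; the host is a hypothetical local server, and the form fields ('images', plus optional 'labels' and 'axes') mirror the ones the route reads:

import requests


def create_project(images_url, labels_url=None, axes=None, host='http://localhost:5000'):
    form = {'images': images_url}
    if labels_url is not None:
        form['labels'] = labels_url
    if axes is not None:
        form['axes'] = axes
    response = requests.post(f'{host}/api/project', data=form)
    response.raise_for_status()
    return response.json()  # the route responds with the new project id


# project_id = create_project('https://example.com/raw.tiff')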
+ """ + if 'labels' not in request.files: + return abort(400, description='Attach labels.zip to download.') + labels_zip = request.files['labels'] + id = request.form['id'] + export = Export(labels_zip) + data = export.export_zip + return send_file(data, as_attachment=True, attachment_filename=f'{id}.zip') + + +@bp.route('/api/upload', methods=['POST']) +def submit_project(): + """ + Create and upload an edited DeepCell Label zip file to an S3 bucket. + The submitted zip should contain the raw and labeled array buffers + in .dat files with the dimensions in dimensions.json, + which are transformed into OME TIFFs in the submitted zips. + """ + start = timeit.default_timer() + if 'labels' not in request.files: + return abort(400, description='Attach labels.zip to submit.') + labels_zip = request.files['labels'] + id = request.form['id'] + bucket = request.form['bucket'] + export = Export(labels_zip) + data = export.export_zip + + # store npz file object in bucket/path + s3 = boto3.client( + 's3', + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + ) + s3.upload_fileobj(data, bucket, f'{id}.zip') + + current_app.logger.debug( + 'Uploaded %s to S3 bucket %s in %s s.', + f'{id}.zip', + bucket, + timeit.default_timer() - start, + ) + return {} diff --git a/backend/deepcell_label/client.py b/backend/deepcell_label/client.py new file mode 100644 index 000000000..23a496540 --- /dev/null +++ b/backend/deepcell_label/client.py @@ -0,0 +1,34 @@ +import asyncio +import json + +import numpy as np +import websockets +from flask import current_app + + +def convert_to_json(to_send, label): + msg = { + 'data': to_send.ravel().tolist(), + 'shape': to_send.shape, + 'dtype': to_send.dtype.descr[0][-1], + 'label': label, + } + return json.dumps(msg) + + +async def perform_send(to_send, label): + uri = 'ws://131.215.2.183:8765' + async with websockets.connect(uri) as websocket: + packet = convert_to_json(to_send, label) + await websocket.send(packet) + print('sent') + serialized = await websocket.recv() + if serialized != 'yes': + msg = json.loads(serialized) + mask = np.array(msg['data'], dtype=msg['dtype']).reshape(msg['shape']) + np.save('deepcell_label/mask.npy', mask) + + +def send_to_server(to_send, label): + current_app.logger.info('Sent to server to generate mask for cellSAM') + asyncio.run(perform_send(to_send, label)) diff --git a/backend/deepcell_label/export.py b/backend/deepcell_label/export.py index de120c6b5..b34fda6f4 100644 --- a/backend/deepcell_label/export.py +++ b/backend/deepcell_label/export.py @@ -1,184 +1,184 @@ -""" -Converts data from the client into a ZIP to export from DeepCell Label. 
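client.py above serializes an array as a flat JSON list alongside its shape and dtype, and reconstructs the returned mask the same way. A round-trip sketch of that encoding, with the decode step factored out into its own helper:

import json

import numpy as np


def convert_to_json(to_send, label):
    # Same encoding as client.py: flat data list plus shape, dtype, and label
    msg = {
        'data': to_send.ravel().tolist(),
        'shape': to_send.shape,
        'dtype': to_send.dtype.descr[0][-1],
        'label': label,
    }
    return json.dumps(msg)


def convert_from_json(serialized):
    # Inverse of convert_to_json, matching the reconstruction in perform_send
    msg = json.loads(serialized)
    return np.array(msg['data'], dtype=msg['dtype']).reshape(msg['shape'])


arr = np.arange(6, dtype=np.int32).reshape(2, 3)
assert (convert_from_json(convert_to_json(arr, label=1)) == arr).all()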
-""" - -import io -import itertools -import json -import zipfile - -import numpy as np -import tifffile - - -class Export: - def __init__(self, labels_zip): - self.labels_zip = labels_zip - self.export_zip = io.BytesIO() - - self.load_dimensions() - self.load_labeled() - self.load_raw() - self.load_channels() - self.load_cells() - - self.labeled, self.cells = rewrite_labeled(self.labeled, self.cells) - - self.write_export_zip() - self.export_zip.seek(0) - - def load_dimensions(self): - """Loads dimensions of the raw and labeled arrays from dimensions.json.""" - with zipfile.ZipFile(self.labels_zip) as zf: - with zf.open('dimensions.json') as f: - dimensions = json.load(f) - self.height = dimensions['height'] - self.width = dimensions['width'] - self.duration = dimensions['duration'] - self.num_channels = dimensions['numChannels'] - self.num_features = dimensions['numFeatures'] - self.dtype = self.get_dtype(dimensions['dtype']) - - def get_dtype(self, arr_type): - """Matches raw array dtype with a numpy dtype""" - mapping = { - 'Uint8Array': np.uint8, - 'Uint16Array': np.uint16, - 'Uint32Array': np.uint32, - 'Int32Array': np.int32, - 'Float32Array': np.float32, - 'Float64Array': np.float64, - } - try: - return mapping[arr_type] - except KeyError: - raise ValueError('Could not match dtype of raw array.') - - def load_labeled(self): - """Loads the labeled array from the labeled.dat file.""" - with zipfile.ZipFile(self.labels_zip) as zf: - with zf.open('labeled.dat') as f: - labeled = np.frombuffer(f.read(), np.int32) - self.labeled = np.reshape( - labeled, - (self.num_features, self.duration, self.height, self.width), - ) - - def load_raw(self): - """Loads the raw array from the raw.dat file.""" - with zipfile.ZipFile(self.labels_zip) as zf: - with zf.open('raw.dat') as f: - raw = np.frombuffer(f.read(), self.dtype) - self.raw = np.reshape( - raw, (self.num_channels, self.duration, self.height, self.width) - ) - - def load_channels(self): - """Loads the channels array from channels.json""" - with zipfile.ZipFile(self.labels_zip) as zf: - with zf.open('channels.json') as f: - self.channels = json.load(f) - - def load_cells(self): - """Loads cell labels from cells.json.""" - with zipfile.ZipFile(self.labels_zip) as zf: - with zf.open('cells.json') as f: - self.cells = json.load(f) - - def write_export_zip(self): - """Writes an export zip with OME TIFF files instead of raw.dat and labeled.dat.""" - # Rewrite all other files in input zip to export zip - with zipfile.ZipFile(self.export_zip, 'w') as export_zf, zipfile.ZipFile( - self.labels_zip - ) as input_zf: - for item in input_zf.infolist(): - # Writes all other files (divisions.json, spots.csv, etc.) 
to export zip - if item.filename not in [ - 'dimensions.json', - 'labeled.dat', - 'raw.dat', - 'cells.json', - ]: - buffer = input_zf.read(item.filename) - export_zf.writestr(item, buffer) - # Write updated cells - export_zf.writestr('cells.json', json.dumps(self.cells)) - # Write OME TIFF for labeled - labeled_ome_tiff = io.BytesIO() - tifffile.imwrite( - labeled_ome_tiff, - self.labeled, - ome=True, - photometric='minisblack', - compression='zlib', - metadata={'axes': 'CZYX'}, - ) - labeled_ome_tiff.seek(0) - export_zf.writestr('y.ome.tiff', labeled_ome_tiff.read()) - # Write OME TIFF for raw - raw_ome_tiff = io.BytesIO() - tifffile.imwrite( - raw_ome_tiff, - self.raw, - ome=True, - photometric='minisblack', - compression='zlib', - metadata={'axes': 'CZYX', 'Channel': {'Name': self.channels}}, - ) - raw_ome_tiff.seek(0) - export_zf.writestr('X.ome.tiff', raw_ome_tiff.read()) - - -def rewrite_labeled(labeled, cells): - """ - Rewrites the labeled to use values from cell labels. - - Args: - labeled: numpy array of shape (num_features, duration, height, width) - cells: list of cells labels like { "cell": 1, "value": 1, "t": 0} - - Returns: - (numpy array of shape (num_features, duration, height, width), cells with updated values) - """ - new_labeled = np.zeros(labeled.shape, dtype=np.int32) - (num_features, duration, height, width) = labeled.shape - new_cells = [] - for c in range(num_features): - cells_in_feature = list(filter(lambda cell, c=c: cell['c'] == c, cells)) - for t in range(duration): - cells_at_t = list( - filter(lambda cell, t=t: cell['t'] == t, cells_in_feature) - ) - values = itertools.groupby(cells_at_t, lambda c: c['value']) - overlap_values = [] - - # Rewrite non-overlapping values with cells - for value, group in values: - group = list(group) - if len(group) == 1: - cell = group[0]['cell'] - frame = labeled[:, t, :, :] - new_labeled[:, t, :, :][frame == value] = cell - new_cells.append({'cell': cell, 'value': cell, 't': t, 'c': c}) - else: - overlap_values.append([value, group]) - - # Rewrite overlapping values with values higher than all cells - if len(cells_at_t) == 0: - new_overlap_value = 0 - else: - new_overlap_value = max(cells_at_t, key=lambda c: c['cell'])['cell'] + 1 - for overlap_value, overlap_cells in overlap_values: - for cell in overlap_cells: - frame = labeled[:, t, :, :] - new_labeled[:, t, :, :][frame == overlap_value] = new_overlap_value - new_cells.append( - { - 'cell': cell['cell'], - 'value': new_overlap_value, - 't': t, - 'c': c, - } - ) - new_overlap_value += 1 - return new_labeled, new_cells +""" +Converts data from the client into a ZIP to export from DeepCell Label. 
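get_dtype (removed above and re-added below) maps JavaScript TypedArray names onto numpy dtypes before raw.dat is decoded with np.frombuffer. The mapping and the decode step in isolation, using a stand-in byte buffer:

import numpy as np

TYPED_ARRAY_DTYPES = {
    'Uint8Array': np.uint8,
    'Uint16Array': np.uint16,
    'Uint32Array': np.uint32,
    'Int32Array': np.int32,
    'Float32Array': np.float32,
    'Float64Array': np.float64,
}

buffer = np.arange(4, dtype=np.uint16).tobytes()  # stand-in for raw.dat bytes
restored = np.frombuffer(buffer, TYPED_ARRAY_DTYPES['Uint16Array'])
assert restored.dtype == np.uint16 and restored.tolist() == [0, 1, 2, 3]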
+""" + +import io +import itertools +import json +import zipfile + +import numpy as np +import tifffile + + +class Export: + def __init__(self, labels_zip): + self.labels_zip = labels_zip + self.export_zip = io.BytesIO() + + self.load_dimensions() + self.load_labeled() + self.load_raw() + self.load_channels() + self.load_cells() + + self.labeled, self.cells = rewrite_labeled(self.labeled, self.cells) + + self.write_export_zip() + self.export_zip.seek(0) + + def load_dimensions(self): + """Loads dimensions of the raw and labeled arrays from dimensions.json.""" + with zipfile.ZipFile(self.labels_zip) as zf: + with zf.open('dimensions.json') as f: + dimensions = json.load(f) + self.height = dimensions['height'] + self.width = dimensions['width'] + self.duration = dimensions['duration'] + self.num_channels = dimensions['numChannels'] + self.num_features = dimensions['numFeatures'] + self.dtype = self.get_dtype(dimensions['dtype']) + + def get_dtype(self, arr_type): + """Matches raw array dtype with a numpy dtype""" + mapping = { + 'Uint8Array': np.uint8, + 'Uint16Array': np.uint16, + 'Uint32Array': np.uint32, + 'Int32Array': np.int32, + 'Float32Array': np.float32, + 'Float64Array': np.float64, + } + try: + return mapping[arr_type] + except KeyError: + raise ValueError('Could not match dtype of raw array.') + + def load_labeled(self): + """Loads the labeled array from the labeled.dat file.""" + with zipfile.ZipFile(self.labels_zip) as zf: + with zf.open('labeled.dat') as f: + labeled = np.frombuffer(f.read(), np.int32) + self.labeled = np.reshape( + labeled, + (self.num_features, self.duration, self.height, self.width), + ) + + def load_raw(self): + """Loads the raw array from the raw.dat file.""" + with zipfile.ZipFile(self.labels_zip) as zf: + with zf.open('raw.dat') as f: + raw = np.frombuffer(f.read(), self.dtype) + self.raw = np.reshape( + raw, (self.num_channels, self.duration, self.height, self.width) + ) + + def load_channels(self): + """Loads the channels array from channels.json""" + with zipfile.ZipFile(self.labels_zip) as zf: + with zf.open('channels.json') as f: + self.channels = json.load(f) + + def load_cells(self): + """Loads cell labels from cells.json.""" + with zipfile.ZipFile(self.labels_zip) as zf: + with zf.open('cells.json') as f: + self.cells = json.load(f) + + def write_export_zip(self): + """Writes an export zip with OME TIFF files instead of raw.dat and labeled.dat.""" + # Rewrite all other files in input zip to export zip + with zipfile.ZipFile(self.export_zip, 'w') as export_zf, zipfile.ZipFile( + self.labels_zip + ) as input_zf: + for item in input_zf.infolist(): + # Writes all other files (divisions.json, spots.csv, etc.) 
to export zip + if item.filename not in [ + 'dimensions.json', + 'labeled.dat', + 'raw.dat', + 'cells.json', + ]: + buffer = input_zf.read(item.filename) + export_zf.writestr(item, buffer) + # Write updated cells + export_zf.writestr('cells.json', json.dumps(self.cells)) + # Write OME TIFF for labeled + labeled_ome_tiff = io.BytesIO() + tifffile.imwrite( + labeled_ome_tiff, + self.labeled, + ome=True, + photometric='minisblack', + compression='zlib', + metadata={'axes': 'CZYX'}, + ) + labeled_ome_tiff.seek(0) + export_zf.writestr('y.ome.tiff', labeled_ome_tiff.read()) + # Write OME TIFF for raw + raw_ome_tiff = io.BytesIO() + tifffile.imwrite( + raw_ome_tiff, + self.raw, + ome=True, + photometric='minisblack', + compression='zlib', + metadata={'axes': 'CZYX', 'Channel': {'Name': self.channels}}, + ) + raw_ome_tiff.seek(0) + export_zf.writestr('X.ome.tiff', raw_ome_tiff.read()) + + +def rewrite_labeled(labeled, cells): + """ + Rewrites the labeled to use values from cell labels. + + Args: + labeled: numpy array of shape (num_features, duration, height, width) + cells: list of cells labels like { "cell": 1, "value": 1, "t": 0} + + Returns: + (numpy array of shape (num_features, duration, height, width), cells with updated values) + """ + new_labeled = np.zeros(labeled.shape, dtype=np.int32) + (num_features, duration, height, width) = labeled.shape + new_cells = [] + for c in range(num_features): + cells_in_feature = list(filter(lambda cell, c=c: cell['c'] == c, cells)) + for t in range(duration): + cells_at_t = list( + filter(lambda cell, t=t: cell['t'] == t, cells_in_feature) + ) + values = itertools.groupby(cells_at_t, lambda c: c['value']) + overlap_values = [] + + # Rewrite non-overlapping values with cells + for value, group in values: + group = list(group) + if len(group) == 1: + cell = group[0]['cell'] + frame = labeled[:, t, :, :] + new_labeled[:, t, :, :][frame == value] = cell + new_cells.append({'cell': cell, 'value': cell, 't': t, 'c': c}) + else: + overlap_values.append([value, group]) + + # Rewrite overlapping values with values higher than all cells + if len(cells_at_t) == 0: + new_overlap_value = 0 + else: + new_overlap_value = max(cells_at_t, key=lambda c: c['cell'])['cell'] + 1 + for overlap_value, overlap_cells in overlap_values: + for cell in overlap_cells: + frame = labeled[:, t, :, :] + new_labeled[:, t, :, :][frame == overlap_value] = new_overlap_value + new_cells.append( + { + 'cell': cell['cell'], + 'value': new_overlap_value, + 't': t, + 'c': c, + } + ) + new_overlap_value += 1 + return new_labeled, new_cells diff --git a/backend/deepcell_label/label.py b/backend/deepcell_label/label.py index 5cd7c5a56..29a5d7f1d 100644 --- a/backend/deepcell_label/label.py +++ b/backend/deepcell_label/label.py @@ -1,420 +1,447 @@ -"""Classes to view and edit DeepCell Label Projects""" -from __future__ import absolute_import, division, print_function - -import io -import json -import zipfile - -import numpy as np -import skimage -from skimage import filters -from skimage.exposure import rescale_intensity -from skimage.measure import regionprops -from skimage.morphology import dilation, disk, erosion, flood, square -from skimage.segmentation import morphological_chan_vese, watershed - - -class Edit(object): - """ - Loads labeled data from a zip file, - edits the labels according to edit.json in the zip, - and writes the edited labels to a new zip file. 
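One subtlety in rewrite_labeled above: itertools.groupby only merges consecutive items with equal keys, so grouping cells_at_t by 'value' assumes entries for the same value sit next to each other in the cells list. A small illustration of the difference sorting makes:

import itertools

cells = [{'value': 1}, {'value': 2}, {'value': 1}]

keys = [k for k, _ in itertools.groupby(cells, lambda c: c['value'])]
assert keys == [1, 2, 1]  # value 1 yields two separate groups

cells.sort(key=lambda c: c['value'])
keys = [k for k, _ in itertools.groupby(cells, lambda c: c['value'])]
assert keys == [1, 2]  # after sorting, one group per value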
- - The labels zipfile must contain: - labeled.dat - a binary array buffer of the labeled data (int32) - overlaps.json - a 2D json array describing values encode which cells - the (i, j)th element of overlaps.json is 1 if value i encodes cell j and 0 otherwise - edit.json - a json object describing the edit to be made including - - action (e.g. ) - - the args for the action - - write_mode: one of 'overlap', 'overwrite', or 'exclude' - - height: the height of the labeled (and raw) arrays - - width: the width of the labeled (and raw) arrays - It additionally may contain: - raw.dat - a binary array buffer of the raw data (uint8) - lineage.json - a json object describing the lineage of the cells - """ - - def __init__(self, labels_zip): - - self.valid_modes = ['overlap', 'overwrite', 'exclude'] - self.raw_required = ['watershed', 'active_contour', 'threshold'] - - self.load(labels_zip) - self.dispatch_action() - self.write_response_zip() - - @property - def new_value(self): - """Returns a value not in the segmentation.""" - if len(self.cells) == 0: - return 1 - return max(map(lambda c: c['value'], self.cells)) + 1 - - @property - def new_cell(self): - """Returns a cell not in the segmentation.""" - if len(self.cells) == 0: - return 1 - return max(map(lambda c: c['cell'], self.cells)) + 1 - - def load(self, labels_zip): - """ - Load the project data to edit from a zip file. - """ - if not zipfile.is_zipfile(labels_zip): - raise ValueError('Attached labels.zip is not a zip file.') - zf = zipfile.ZipFile(labels_zip) - - # Load edit args - if 'edit.json' not in zf.namelist(): - raise ValueError('Attached labels.zip must contain edit.json.') - with zf.open('edit.json') as f: - edit = json.load(f) - if 'action' not in edit: - raise ValueError('No action specified in edit.json.') - self.action = edit['action'] - self.height = edit['height'] - self.width = edit['width'] - self.args = edit.get('args', None) - # TODO: specify write mode per cell? - self.write_mode = edit.get('writeMode', 'overlap') - if self.write_mode not in self.valid_modes: - raise ValueError( - f'Invalid writeMode {self.write_mode} in edit.json. Choose from cell, overwrite, or exclude.' - ) - - # Load label array - if 'labeled.dat' not in zf.namelist(): - raise ValueError('zip must contain labeled.dat.') - with zf.open('labeled.dat') as f: - labels = np.frombuffer(f.read(), np.int32) - self.initial_labels = np.reshape(labels, (self.height, self.width)) - self.labels = self.initial_labels.copy() - - # Load cells array - if 'cells.json' not in zf.namelist(): - raise ValueError('zip must contain cells.json.') - with zf.open('cells.json') as f: - self.cells = json.load(f) - - # Load raw image - if 'raw.dat' in zf.namelist(): - with zf.open('raw.dat') as f: - raw = np.frombuffer(f.read(), np.uint8) - self.raw = np.reshape(raw, (self.width, self.height)) - elif self.action in self.raw_required: - raise ValueError( - f'Include raw array in raw.json to use action {self.action}.' 
- ) - - def write_response_zip(self): - """Write edited segmentation to zip.""" - f = io.BytesIO() - with zipfile.ZipFile(f, 'w', compression=zipfile.ZIP_DEFLATED) as zf: - zf.writestr('labeled.dat', self.labels.tobytes()) - # Remove cell labels that are not in the segmentation - values = np.unique(self.labels) - self.cells = list(filter(lambda c: c['value'] in values, self.cells)) - zf.writestr('cells.json', json.dumps(self.cells)) - f.seek(0) - self.response_zip = f - - def get_cells(self, value): - """ - Returns a list of cells encoded by the value - """ - return list( - map(lambda c: c['cell'], filter(lambda c: c['value'] == value, self.cells)) - ) - - def get_values(self, cell): - """ - Returns a list of values that encode a cell - """ - return list( - map(lambda c: c['value'], filter(lambda c: c['cell'] == cell, self.cells)) - ) - - def get_value(self, cells): - """ - Returns the value that encodes the list of cells - """ - if cells == []: - return 0 - values = set(map(lambda c: c['value'], self.cells)) - for cell in cells: - values = values & set(self.get_values(cell)) - for value in values: - if set(self.get_cells(value)) == set(cells): - return value - value = self.new_value - for cell in cells: - self.cells.append({'value': value, 'cell': cell}) - return value - - def get_mask(self, cell): - """ - Returns a boolean mask of the cell (or the background when cell == 0) - """ - if cell == 0: - return self.labels == 0 - mask = np.zeros(self.labels.shape, dtype=bool) - for value in self.get_values(cell): - mask[self.labels == value] = True - return mask - - def add_mask(self, mask, cell): - self.labels = self.clean_labels(self.labels, self.cells) - if self.write_mode == 'overwrite': - self.labels[mask] = self.get_value([cell]) - elif self.write_mode == 'exclude': - mask = mask & (self.labels == 0) - self.labels[mask] = self.get_value([cell]) - else: # self.write_mode == 'overlap' - self.overlap_mask(mask, cell) - - def remove_mask(self, mask, cell): - self.overlap_mask(mask, cell, remove=True) - - def overlap_mask(self, mask, cell, remove=False): - """ - Adds the cell to the segmentation in the mask area, - overlapping with existing cells. - """ - # Rewrite values inside mask to encode label - values = np.unique(self.labels[mask]) - for value in values: - # Get value to encode new set of labels - cells = self.get_cells(value) - if remove: - if cell in cells: - cells.remove(cell) - else: - cells.append(cell) - new_value = self.get_value(cells) - self.labels[mask & (self.labels == value)] = new_value - - def clean_cell(self, cell): - """Ensures that a cell is a positive integer""" - return int(max(0, cell)) - - def clean_labels(self, labeled, cells): - """ - Ensures that labels do not include any values that do not correspond - to cells (eg. for deleted cells.) - - Args: - labeled: numpy array of shape (height, width) - cells: list of cells labels like { "cell": 1, "value": 1, "t": 0} - - Returns: - (numpy array of shape (height, width), cells with updated values) - """ - values = [cell['value'] for cell in cells] # get list of values - deleted_mask = np.isin(labeled, values, invert=True) - labeled[deleted_mask] = 0 # delete any labels not in values - return labeled - - def dispatch_action(self): - """ - Call an action method based on an action type. - - Args: - action (str): name of action method after "action_" - e.g. 
"draw" to call "action_draw" - info (dict): key value pairs with arguments for action - """ - attr_name = 'action_{}'.format(self.action) - try: - action_fn = getattr(self, attr_name) - except AttributeError: - raise ValueError('Invalid action "{}"'.format(self.action)) - action_fn(**self.args) - - def action_draw(self, trace, brush_size, cell, erase=False): - """ - Use a "brush" to draw in the brush value along trace locations of - the annotated data. - - Args: - trace (list): list of (x, y) coordinates where the brush has painted - brush_size (int): radius of the brush in pixels - cell (int): cell to edit with the brush - erase (bool): whether to add or remove label from brush stroke area - """ - trace = json.loads(trace) - # Create mask for brush stroke - brush_mask = np.zeros(self.labels.shape, dtype=bool) - for loc in trace: - x = loc[0] - y = loc[1] - disk = skimage.draw.disk((y, x), brush_size, shape=self.labels.shape) - brush_mask[disk] = True - - if erase: - self.remove_mask(brush_mask, cell) - else: - self.add_mask(brush_mask, cell) - - def action_trim_pixels(self, cell, x, y): - """ - Removes parts of cell not connected to (x, y). - - Args: - cell (int): cell to trim - x (int): x position of seed - y (int): y position of seed - """ - mask = self.get_mask(cell) - if mask[y, x]: - connected_mask = flood(mask, (y, x)) - self.remove_mask(~connected_mask, cell) - - # TODO: come back to flooding with overlaps... - def action_flood(self, foreground, background, x, y): - """ - Floods the connected component of the background label at (x, y) with the foreground label. - When the background label is 0, does not flood diagonally connected pixels. - - Args: - foreground (int): label to flood with - bacgkround (int): label to flood - x (int): x coordinate of region to flood - y (int): y coordinate of region to flood - """ - mask = self.get_mask(background) - flooded = flood(mask, (y, x), connectivity=2 if background != 0 else 1) - self.add_mask(flooded, foreground) - - def action_watershed(self, cell, new_cell, x1, y1, x2, y2): - """Use watershed to segment different objects""" - # Create markers for to seed watershed labels - markers = np.zeros(self.labels.shape) - markers[y1, x1] = cell - markers[y2, x2] = new_cell - - # Cut images to cell bounding box - mask = self.get_mask(cell) - props = regionprops(mask.astype(np.uint8)) - top, left, bottom, right = props[0].bbox - raw = np.copy(self.raw[top:bottom, left:right]) - markers = np.copy(markers[top:bottom, left:right]) - mask = np.copy(mask[top:bottom, left:right]) - - # Contrast adjust and invert the raw image - raw = -rescale_intensity(raw) - # Apply watershed - results = watershed(raw, markers, mask=mask) - - # Dilate small cells to prevent "dimmer" cell from being eroded by the "brighter" cell - if np.sum(results == new_cell) < 5: - dilated = dilation(results == new_cell, disk(3)) - results[dilated] = new_cell - if np.sum(results == cell) < 5: - dilated = dilation(results == cell, disk(3)) - results[dilated] = cell - - # Update cells where watershed changed cell - new_cell_mask = np.zeros(self.labels.shape, dtype=bool) - cell_mask = np.zeros(self.labels.shape, dtype=bool) - new_cell_mask[top:bottom, left:right] = results == new_cell - cell_mask[top:bottom, left:right] = results == cell - self.remove_mask(self.get_mask(cell), cell) - self.add_mask(cell_mask, cell) - self.add_mask(new_cell_mask, new_cell) - - def action_threshold(self, y1, x1, y2, x2, cell): - """ - Threshold the raw image for annotation prediction within the - 
user-determined bounding box. - - Args: - y1 (int): first y coordinate to bound threshold area - x1 (int): first x coordinate to bound threshold area - y2 (int): second y coordinate to bound threshold area - x2 (int): second x coordinate to bound threshold area - cell (int): cell drawn in threshold area - """ - cell = self.clean_cell(cell) - # Make bounding box from coordinates - top = min(y1, y2) - bottom = max(y1, y2) + 1 - left = min(x1, x2) - right = max(x1, x2) + 1 - image = self.raw[top:bottom, left:right].astype('float64') - # Hysteresis thresholding strategy needs two thresholds - # triangle threshold picked after trying a few on one dataset - # it may not be the best approach for other datasets! - low = filters.threshold_triangle(image=image) - high = 1.10 * low - # Limit stray pixelst - thresholded = filters.apply_hysteresis_threshold(image, low, high) - mask = np.zeros(self.labels.shape, dtype=bool) - mask[top:bottom, left:right] = thresholded - self.add_mask(mask, cell) - - def action_active_contour(self, cell, min_pixels=20, iterations=100, dilate=0): - """ - Uses active contouring to reshape a cell to match the raw image. - """ - mask = self.get_mask(cell) - # Limit contouring to a bounding box twice the size of the cell - props = regionprops(mask.astype(np.uint8))[0] - top, left, bottom, right = props.bbox - cell_height = bottom - top - cell_width = right - left - # Double size of bounding box - height, width = self.labels.shape - top = max(0, top - height // 2) - bottom = min(height, bottom + cell_height // 2) - left = max(0, left - width // 2) - right = min(width, right + cell_width // 2) - - # Contour the cell - init_level_set = mask[top:bottom, left:right] - # Normalize to range [0., 1.] - _vmin, _vmax = self.raw.min(), self.raw.max() - if _vmin == _vmax: - image = np.zeros_like(self.raw) - else: - image = self.raw.copy() - image -= _vmin - image = image / (_vmax - _vmin) - image = image[top:bottom, left:right] - contoured = morphological_chan_vese( - image, iterations, init_level_set=init_level_set - ) - - # Dilate to adjust for tight fit - contoured = dilation(contoured, disk(dilate)) - - # Keep only the largest connected component - regions = skimage.measure.label(contoured) - if np.any(regions): - largest_component = regions == ( - np.argmax(np.bincount(regions.flat)[1:]) + 1 - ) - mask = np.zeros(self.labels.shape, dtype=bool) - mask[top:bottom, left:right] = largest_component - - # Throw away small contoured cells - if np.count_nonzero(mask) >= min_pixels: - self.remove_mask(~mask, cell) - self.add_mask(mask, cell) - - def action_erode(self, cell): - """ - Shrink the selected cell. - """ - mask = self.get_mask(cell) - eroded = erosion(mask, square(3)) - self.remove_mask(mask & ~eroded, cell) - - def action_dilate(self, cell): - """ - Expand the selected cell. 
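action_erode and action_dilate above shrink or grow a cell mask by one morphological step with a 3x3 structuring element. The effect on a toy mask:

import numpy as np
from skimage.morphology import dilation, erosion, square

mask = np.zeros((5, 5), dtype=bool)
mask[1:4, 1:4] = True  # a 3x3 block of True pixels

eroded = erosion(mask, square(3))
assert eroded.sum() == 1  # only the center pixel survives one erosion

dilated = dilation(mask, square(3))
assert dilated.sum() == 25  # one dilation grows the block to fill the frame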
- """ - mask = self.get_mask(cell) - dilated = dilation(mask, square(3)) - self.add_mask(dilated, cell) +"""Classes to view and edit DeepCell Label Projects""" +from __future__ import absolute_import, division, print_function + +import io +import json +import zipfile + +import numpy as np +import skimage + +# from flask import current_app +from skimage import filters +from skimage.exposure import rescale_intensity +from skimage.measure import regionprops +from skimage.morphology import dilation, disk, erosion, flood, square +from skimage.segmentation import morphological_chan_vese, watershed + +# from deepcell_label.samlog import generate_masks + + +class Edit(object): + """ + Loads labeled data from a zip file, + edits the labels according to edit.json in the zip, + and writes the edited labels to a new zip file. + + The labels zipfile must contain: + labeled.dat - a binary array buffer of the labeled data (int32) + overlaps.json - a 2D json array describing values encode which cells + the (i, j)th element of overlaps.json is 1 if value i encodes cell j and 0 otherwise + edit.json - a json object describing the edit to be made including + - action (e.g. ) + - the args for the action + - write_mode: one of 'overlap', 'overwrite', or 'exclude' + - height: the height of the labeled (and raw) arrays + - width: the width of the labeled (and raw) arrays + It additionally may contain: + raw.dat - a binary array buffer of the raw data (uint8) + lineage.json - a json object describing the lineage of the cells + """ + + def __init__(self, labels_zip): + + self.valid_modes = ['overlap', 'overwrite', 'exclude'] + self.raw_required = ['watershed', 'active_contour', 'threshold'] + + self.load(labels_zip) + self.dispatch_action() + self.write_response_zip() + + @property + def new_value(self): + """Returns a value not in the segmentation.""" + if len(self.cells) == 0: + return 1 + return max(map(lambda c: c['value'], self.cells)) + 1 + + @property + def new_cell(self): + """Returns a cell not in the segmentation.""" + if len(self.cells) == 0: + return 1 + return max(map(lambda c: c['cell'], self.cells)) + 1 + + def load(self, labels_zip): + """ + Load the project data to edit from a zip file. + """ + if not zipfile.is_zipfile(labels_zip): + raise ValueError('Attached labels.zip is not a zip file.') + zf = zipfile.ZipFile(labels_zip) + + # Load edit args + if 'edit.json' not in zf.namelist(): + raise ValueError('Attached labels.zip must contain edit.json.') + with zf.open('edit.json') as f: + edit = json.load(f) + if 'action' not in edit: + raise ValueError('No action specified in edit.json.') + self.action = edit['action'] + self.height = edit['height'] + self.width = edit['width'] + self.args = edit.get('args', None) + # TODO: specify write mode per cell? + self.write_mode = edit.get('writeMode', 'overlap') + if self.write_mode not in self.valid_modes: + raise ValueError( + f'Invalid writeMode {self.write_mode} in edit.json. Choose from cell, overwrite, or exclude.' 
+ ) + + # Load label array + if 'labeled.dat' not in zf.namelist(): + raise ValueError('zip must contain labeled.dat.') + with zf.open('labeled.dat') as f: + labels = np.frombuffer(f.read(), np.int32) + self.initial_labels = np.reshape(labels, (self.height, self.width)) + self.labels = self.initial_labels.copy() + + # Load cells array + if 'cells.json' not in zf.namelist(): + raise ValueError('zip must contain cells.json.') + with zf.open('cells.json') as f: + self.cells = json.load(f) + + # Load raw image + if 'raw.dat' in zf.namelist(): + with zf.open('raw.dat') as f: + raw = np.frombuffer(f.read(), np.uint8) + self.raw = np.reshape(raw, (self.width, self.height)) + elif self.action in self.raw_required: + raise ValueError( + f'Include raw array in raw.json to use action {self.action}.' + ) + + def write_response_zip(self): + """Write edited segmentation to zip.""" + f = io.BytesIO() + with zipfile.ZipFile(f, 'w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('labeled.dat', self.labels.tobytes()) + # Remove cell labels that are not in the segmentation + values = np.unique(self.labels) + self.cells = list(filter(lambda c: c['value'] in values, self.cells)) + zf.writestr('cells.json', json.dumps(self.cells)) + f.seek(0) + self.response_zip = f + + def get_cells(self, value): + """ + Returns a list of cells encoded by the value + """ + return list( + map(lambda c: c['cell'], filter(lambda c: c['value'] == value, self.cells)) + ) + + def get_values(self, cell): + """ + Returns a list of values that encode a cell + """ + return list( + map(lambda c: c['value'], filter(lambda c: c['cell'] == cell, self.cells)) + ) + + def get_value(self, cells): + """ + Returns the value that encodes the list of cells + """ + if cells == []: + return 0 + values = set(map(lambda c: c['value'], self.cells)) + for cell in cells: + values = values & set(self.get_values(cell)) + for value in values: + if set(self.get_cells(value)) == set(cells): + return value + value = self.new_value + for cell in cells: + self.cells.append({'value': value, 'cell': cell}) + return value + + def get_mask(self, cell): + """ + Returns a boolean mask of the cell (or the background when cell == 0) + """ + if cell == 0: + return self.labels == 0 + mask = np.zeros(self.labels.shape, dtype=bool) + for value in self.get_values(cell): + mask[self.labels == value] = True + return mask + + def add_mask(self, mask, cell): + self.labels = self.clean_labels(self.labels, self.cells) + if self.write_mode == 'overwrite': + self.labels[mask] = self.get_value([cell]) + elif self.write_mode == 'exclude': + mask = mask & (self.labels == 0) + self.labels[mask] = self.get_value([cell]) + else: # self.write_mode == 'overlap' + self.overlap_mask(mask, cell) + + def remove_mask(self, mask, cell): + self.overlap_mask(mask, cell, remove=True) + + def overlap_mask(self, mask, cell, remove=False): + """ + Adds the cell to the segmentation in the mask area, + overlapping with existing cells. 
+        """
+        # Rewrite values inside mask to encode label
+        values = np.unique(self.labels[mask])
+        for value in values:
+            # Get value to encode new set of labels
+            cells = self.get_cells(value)
+            if remove:
+                if cell in cells:
+                    cells.remove(cell)
+            else:
+                cells.append(cell)
+            new_value = self.get_value(cells)
+            self.labels[mask & (self.labels == value)] = new_value
+
+    def clean_cell(self, cell):
+        """Ensures that a cell is a non-negative integer"""
+        return int(max(0, cell))
+
+    def clean_labels(self, labeled, cells):
+        """
+        Ensures that labels do not include any values that do not correspond
+        to cells (e.g. for deleted cells).
+
+        Args:
+            labeled: numpy array of shape (height, width)
+            cells: list of cell labels like { "cell": 1, "value": 1, "t": 0}
+
+        Returns:
+            numpy array of shape (height, width) with values not in cells set to 0
+        """
+        values = [cell['value'] for cell in cells]  # get list of values
+        deleted_mask = np.isin(labeled, values, invert=True)
+        labeled[deleted_mask] = 0  # delete any labels not in values
+        return labeled
+
+    def dispatch_action(self):
+        """
+        Call an action method based on an action type.
+
+        Args:
+            action (str): name of action method after "action_"
+                e.g. "draw" to call "action_draw"
+            info (dict): key value pairs with arguments for action
+        """
+        attr_name = 'action_{}'.format(self.action)
+        try:
+            action_fn = getattr(self, attr_name)
+        except AttributeError:
+            raise ValueError('Invalid action "{}"'.format(self.action))
+        action_fn(**self.args)
+
+    def action_draw(self, trace, brush_size, cell, erase=False):
+        """
+        Use a "brush" to draw in the brush value along trace locations of
+        the annotated data.
+
+        Args:
+            trace (list): list of (x, y) coordinates where the brush has painted
+            brush_size (int): radius of the brush in pixels
+            cell (int): cell to edit with the brush
+            erase (bool): whether to add or remove label from brush stroke area
+        """
+        trace = json.loads(trace)
+        # Create mask for brush stroke
+        brush_mask = np.zeros(self.labels.shape, dtype=bool)
+        for loc in trace:
+            x = loc[0]
+            y = loc[1]
+            disk = skimage.draw.disk((y, x), brush_size, shape=self.labels.shape)
+            brush_mask[disk] = True
+
+        if erase:
+            self.remove_mask(brush_mask, cell)
+        else:
+            self.add_mask(brush_mask, cell)
+
+    def action_trim_pixels(self, cell, x, y):
+        """
+        Removes parts of cell not connected to (x, y).
+
+        Args:
+            cell (int): cell to trim
+            x (int): x position of seed
+            y (int): y position of seed
+        """
+        mask = self.get_mask(cell)
+        if mask[y, x]:
+            connected_mask = flood(mask, (y, x))
+            self.remove_mask(~connected_mask, cell)
+
+    # TODO: come back to flooding with overlaps...
+    def action_flood(self, foreground, background, x, y):
+        """
+        Floods the connected component of the background label at (x, y) with the foreground label.
+        When the background label is 0, does not flood diagonally connected pixels.
+
+        Args:
+            foreground (int): label to flood with
+            background (int): label to flood
+            x (int): x coordinate of region to flood
+            y (int): y coordinate of region to flood
+        """
+        mask = self.get_mask(background)
+        flooded = flood(mask, (y, x), connectivity=2 if background != 0 else 1)
+        self.add_mask(flooded, foreground)
+
+    def action_watershed(self, cell, new_cell, x1, y1, x2, y2):
+        """Use watershed to segment different objects"""
+        # Create markers to seed the watershed labels
+        markers = np.zeros(self.labels.shape)
+        markers[y1, x1] = cell
+        markers[y2, x2] = new_cell
+
+        # Cut images to cell bounding box
+        mask = self.get_mask(cell)
+        props = regionprops(mask.astype(np.uint8))
+        top, left, bottom, right = props[0].bbox
+        raw = np.copy(self.raw[top:bottom, left:right])
+        markers = np.copy(markers[top:bottom, left:right])
+        mask = np.copy(mask[top:bottom, left:right])
+
+        # Contrast adjust and invert the raw image
+        raw = -rescale_intensity(raw)
+        # Apply watershed
+        results = watershed(raw, markers, mask=mask)
+
+        # Dilate small cells to prevent "dimmer" cell from being eroded by the "brighter" cell
+        if np.sum(results == new_cell) < 5:
+            dilated = dilation(results == new_cell, disk(3))
+            results[dilated] = new_cell
+        if np.sum(results == cell) < 5:
+            dilated = dilation(results == cell, disk(3))
+            results[dilated] = cell
+
+        # Update cells where watershed changed cell
+        new_cell_mask = np.zeros(self.labels.shape, dtype=bool)
+        cell_mask = np.zeros(self.labels.shape, dtype=bool)
+        new_cell_mask[top:bottom, left:right] = results == new_cell
+        cell_mask[top:bottom, left:right] = results == cell
+        self.remove_mask(self.get_mask(cell), cell)
+        self.add_mask(cell_mask, cell)
+        self.add_mask(new_cell_mask, new_cell)
+
+    def action_threshold(self, y1, x1, y2, x2, cell):
+        """
+        Threshold the raw image for annotation prediction within the
+        user-determined bounding box.
+
+        Args:
+            y1 (int): first y coordinate to bound threshold area
+            x1 (int): first x coordinate to bound threshold area
+            y2 (int): second y coordinate to bound threshold area
+            x2 (int): second x coordinate to bound threshold area
+            cell (int): cell drawn in threshold area
+        """
+        cell = self.clean_cell(cell)
+        # Make bounding box from coordinates
+        top = min(y1, y2)
+        bottom = max(y1, y2) + 1
+        left = min(x1, x2)
+        right = max(x1, x2) + 1
+        image = self.raw[top:bottom, left:right].astype('float64')
+        # Hysteresis thresholding strategy needs two thresholds
+        # triangle threshold picked after trying a few on one dataset
+        # it may not be the best approach for other datasets!
+        low = filters.threshold_triangle(image=image)
+        high = 1.10 * low
+        # Limit stray pixels
+        thresholded = filters.apply_hysteresis_threshold(image, low, high)
+        mask = np.zeros(self.labels.shape, dtype=bool)
+        mask[top:bottom, left:right] = thresholded
+        self.add_mask(mask, cell)
+
+    def action_active_contour(self, cell, min_pixels=20, iterations=100, dilate=0):
+        """
+        Uses active contouring to reshape a cell to match the raw image.
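The same two-threshold recipe as action_threshold above, demonstrated on a synthetic image (made up for the sketch): the triangle method picks the low threshold and the high threshold sits 10% above it.

import numpy as np
from skimage import filters

rng = np.random.default_rng(0)
image = rng.normal(loc=0.2, scale=0.05, size=(64, 64))
image[24:40, 24:40] += 0.6  # a bright square on a dim background

low = filters.threshold_triangle(image=image)
high = 1.10 * low
thresholded = filters.apply_hysteresis_threshold(image, low, high)
assert thresholded[32, 32] and not thresholded[0, 0]  # square kept, background dropped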
+ """ + mask = self.get_mask(cell) + # Limit contouring to a bounding box twice the size of the cell + props = regionprops(mask.astype(np.uint8))[0] + top, left, bottom, right = props.bbox + cell_height = bottom - top + cell_width = right - left + # Double size of bounding box + height, width = self.labels.shape + top = max(0, top - height // 2) + bottom = min(height, bottom + cell_height // 2) + left = max(0, left - width // 2) + right = min(width, right + cell_width // 2) + + # Contour the cell + init_level_set = mask[top:bottom, left:right] + # Normalize to range [0., 1.] + _vmin, _vmax = self.raw.min(), self.raw.max() + if _vmin == _vmax: + image = np.zeros_like(self.raw) + else: + image = self.raw.copy() + image -= _vmin + image = image / (_vmax - _vmin) + image = image[top:bottom, left:right] + contoured = morphological_chan_vese( + image, iterations, init_level_set=init_level_set + ) + + # Dilate to adjust for tight fit + contoured = dilation(contoured, disk(dilate)) + + # Keep only the largest connected component + regions = skimage.measure.label(contoured) + if np.any(regions): + largest_component = regions == ( + np.argmax(np.bincount(regions.flat)[1:]) + 1 + ) + mask = np.zeros(self.labels.shape, dtype=bool) + mask[top:bottom, left:right] = largest_component + + # Throw away small contoured cells + if np.count_nonzero(mask) >= min_pixels: + self.remove_mask(~mask, cell) + self.add_mask(mask, cell) + + def action_erode(self, cell): + """ + Shrink the selected cell. + """ + mask = self.get_mask(cell) + eroded = erosion(mask, square(3)) + self.remove_mask(mask & ~eroded, cell) + + def action_dilate(self, cell): + """ + Expand the selected cell. + """ + mask = self.get_mask(cell) + dilated = dilation(mask, square(3)) + self.add_mask(dilated, cell) + + def action_segment_all(self, cell): + # client-side only + # self.labels = generate_masks(self.labels) + # self.labels = self.labels.astype(np.int32).T + # with server + # self.labels = np.load("deepcell_label/mask.npy").astype(np.int32).T + if len(self.labels.shape) == 2: + self.labels = np.expand_dims(np.expand_dims(self.labels, 0), 3) + cells = [] + for t in range(self.labels.shape[0]): + for c in range(self.labels.shape[-1]): + for value in np.unique(self.labels[t, :, :, c]): + if value != 0: + cells.append( + { + 'cell': int(value), + 'value': int(value), + 't': int(t), + 'c': int(c), + } + ) + self.cells = cells diff --git a/backend/deepcell_label/loaders.py b/backend/deepcell_label/loaders.py index 53c2d5a62..ef9ccd7ed 100644 --- a/backend/deepcell_label/loaders.py +++ b/backend/deepcell_label/loaders.py @@ -1,645 +1,656 @@ -""" -Class to load data into a DeepCell Label project file -Loads both raw image data and labels -""" - -import io -import itertools -import json -import re -import tarfile -import tempfile -import zipfile -from xml.etree import ElementTree as ET - -import magic -import numpy as np -from PIL import Image -from tifffile import TiffFile, TiffWriter - -from deepcell_label.utils import convert_lineage, reshape - - -class Loader: - """ - Loads and writes data into a DeepCell Label project zip. 
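action_segment_all above (and write_cells in the loader below) rebuild the cells list by giving every nonzero value in each frame and feature a cell with the same id. The construction in isolation, on a toy (t, y, x, c) array:

import numpy as np

labeled = np.zeros((1, 4, 4, 1), dtype=np.int32)
labeled[0, :2, :2, 0] = 1
labeled[0, 2:, 2:, 0] = 2

cells = [
    {'cell': int(value), 'value': int(value), 't': int(t), 'c': int(c)}
    for t in range(labeled.shape[0])
    for c in range(labeled.shape[-1])
    for value in np.unique(labeled[t, :, :, c])
    if value != 0
]
assert cells == [
    {'cell': 1, 'value': 1, 't': 0, 'c': 0},
    {'cell': 2, 'value': 2, 't': 0, 'c': 0},
]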
- """ - - def __init__(self, image_file=None, label_file=None, axes=None): - """ - Args: - image_file: file zip object containing a png, zip, tiff, or npz file - label_file: file like object containing a zip - axes: dimension order of the image data - """ - self.X = None - self.y = None - self.spots = None - self.divisions = None - self.cellTypes = None - self.cells = None - self.channels = [] - self.embeddings = None - - self.image_file = image_file - self.label_file = label_file if label_file else image_file - self.axes = axes - - with tempfile.TemporaryFile() as project_file: - with zipfile.ZipFile(project_file, 'w', zipfile.ZIP_DEFLATED) as zip: - self.zip = zip - self.load() - self.write() - project_file.seek(0) - self.data = project_file.read() - - def load(self): - """Loads data from input files.""" - self.X = load_images(self.image_file, self.axes) - self.y = load_segmentation(self.label_file) - self.spots = load_spots(self.label_file) - self.divisions = load_divisions(self.label_file) - self.cellTypes = load_cellTypes(self.label_file) - self.cells = load_cells(self.label_file) - self.channels = load_channels(self.image_file) - self.embeddings = load_embeddings(self.label_file) - - if self.y is None: - shape = (*self.X.shape[:-1], 1) - self.y = np.zeros(shape) - - def write(self): - """Writes loaded data to zip.""" - self.write_images() - self.write_segmentation() - self.write_spots() - self.write_divisions() - self.write_cellTypes() - self.write_cells() - self.write_embeddings() - - def write_images(self): - """ - Writes raw images to X.ome.tiff in the output zip. - - Raises: - ValueError: no image data has been loaded to write - """ - X = self.X - if X is not None: - # Move channel axis - X = np.moveaxis(X, -1, 1) - images = io.BytesIO() - channels = [] - for i in range(len(self.channels)): - channels.append({'Name': self.channels[i]}) - with TiffWriter(images, ome=True) as tif: - tif.write( - X, - compression='zlib', - metadata={'axes': 'ZCYX', 'Pixels': {'Channel': channels}}, - ) - images.seek(0) - self.zip.writestr('X.ome.tiff', images.read()) - # else: - # raise ValueError('No images found in files') - - def write_segmentation(self): - """Writes segmentation to y.ome.tiff in the output zip.""" - y = self.y - if y.shape[:-1] != self.X.shape[:-1]: - raise ValueError( - 'Segmentation shape %s is incompatible with image shape %s' - % (y.shape, self.X.shape) - ) - # TODO: check if float vs int matters - y = y.astype(np.int32) - # Move channel axis - y = np.moveaxis(y, -1, 1) - - segmentation = io.BytesIO() - with TiffWriter(segmentation, ome=True) as tif: - tif.write(y, compression='zlib', metadata={'axes': 'ZCYX'}) - segmentation.seek(0) - self.zip.writestr('y.ome.tiff', segmentation.read()) - - def write_spots(self): - """Writes spots to spots.csv in the output zip.""" - if self.spots is not None: - buffer = io.BytesIO() - buffer.write(self.spots) - buffer.seek(0) - self.zip.writestr('spots.csv', buffer.read()) - - def write_divisions(self): - """Writes divisions to divisions.json in the output zip.""" - self.zip.writestr('divisions.json', json.dumps(self.divisions)) - - def write_cellTypes(self): - """Writes cell types to cellTypes.json in the output zip.""" - self.zip.writestr('cellTypes.json', json.dumps(self.cellTypes)) - - def write_embeddings(self): - """Writes embeddings to embeddings.json in the output zip.""" - self.zip.writestr('embeddings.json', json.dumps(self.embeddings)) - - def write_cells(self): - """Writes cells to cells.json in the output zip.""" - if 
self.cells is None: - cells = [] - for t in range(self.y.shape[0]): - for c in range(self.y.shape[-1]): - for value in np.unique(self.y[t, :, :, c]): - if value != 0: - cells.append( - { - 'cell': int(value), - 'value': int(value), - 't': int(t), - 'c': int(c), - } - ) - self.cells = cells - self.zip.writestr('cells.json', json.dumps(self.cells)) - - -def load_images(image_file, axes=None): - """ - Loads image data from image file. - - Args: - image_file: zip, npy, tiff, or png file object containing image data - - Returns: - numpy array or None if no image data found - """ - X = load_zip(image_file) - if X is None: - X = load_npy(image_file) - if X is None: - X = load_tiff(image_file, axes) - if X is None: - X = load_png(image_file) - if X is None: - X = load_trk(image_file, filename='raw.npy') - return X - - -def load_segmentation(f): - """ - Loads segmentation array from label file. - - Args: - label_file: file with zipped npy or tiff containing segmentation data - - Returns: - numpy array or None if no segmentation data found - """ - f.seek(0) - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - y = load_zip_numpy(zf, name='y') - if y is None: - y = load_zip_tiffs(zf, filename='y.ome.tiff') - return y - if tarfile.is_tarfile(f.name): - return load_trk(f, filename='tracked.npy') - - -def load_spots(f): - """ - Load spots data from label file. - - Args: - f: file with zipped csv containing spots data - - Returns: - bytes read from csv in zip or None if no csv in zip - """ - f.seek(0) - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - return load_zip_csv(zf) - - -def load_divisions(f): - """ - Load divisions from divisions.json in project archive - - Loading from lineage.json from .trk file is supported, but deprecated. - - Args: - f: zip file with divisions.json - or tarfile with lineage.json - - Returns: - dict or None if divisions.json not found - """ - f.seek(0) - divisions = None - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - divisions = load_zip_json(zf, filename='divisions.json') - lineage = load_zip_json(zf, filename='lineage.json') - if lineage: - divisions = convert_lineage(lineage) - elif tarfile.is_tarfile(f.name): - lineage = load_trk(f, filename='lineage.json') - divisions = convert_lineage(lineage) - if divisions is None: - return [] - return divisions - - -def load_cellTypes(f): - """ - Load cell types from cellTypes.json in project archive - - Args: - f: zip file with cellTypes.json - - Returns: - dict or None if cellTypes.json not found - """ - f.seek(0) - cellTypes = None - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - cellTypes = load_zip_json(zf, filename='cellTypes.json') - if cellTypes is None: - return [] - return cellTypes - - -def load_embeddings(f): - """ - Load embeddings from embeddings.json in project archive - - Args: - f: zip file with embeddings.json - - Returns: - dict or None if embeddings.json not found - """ - f.seek(0) - embeddings = None - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - embeddings = load_zip_json(zf, filename='embeddings.json') - return embeddings - - -def load_cells(f): - """ - Load cells from label file. - - Args: - f: zip file with cells json - - Returns: - dict or None if no json in zip - """ - f.seek(0) - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - return load_zip_json(zf, filename='cells.json') - - -def load_channels(f): - """ - Load channels from raw file. 
- - Args: - f: X.ome.tiff or zip with X.ome.tiff with channel metadata - - Returns: - list or None if no channel metadata - """ - f.seek(0) - channels = [] - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - for filename in zf.namelist(): - if filename == 'X.ome.tiff': - with zf.open(filename) as X: - tiff = TiffFile(X) - if tiff.is_ome: - root = ET.fromstring(tiff.ome_metadata) - for child in root.iter(): - if child.tag.endswith('Channel') and 'Name' in child.attrib: - channels.append(child.attrib['Name']) - return channels - - -def load_zip_numpy(zf, name='X'): - """ - Loads a numpy array from the zip file - If loading an NPZ with multiple arrays, name selects which one to load - - Args: - zf: a ZipFile with a npy or npz file - name (str): name of the array to load - - Returns: - numpy array or None if no png in zip - """ - for filename in zf.namelist(): - if filename == f'{name}.npy': - with zf.open(filename) as f: - return np.load(f) - if filename.endswith('.npz'): - with zf.open(filename) as f: - npz = np.load(f) - return npz[name] if name in npz.files else npz[npz.files[0]] - - -def load_zip_tiffs(zf, filename): - """ - Returns an array with all tiff image data in the zip file - - Args: - zf: a ZipFile containing tiffs to load - - Returns: - numpy array or None if no tiffs in zip - """ - if filename in zf.namelist(): - with zf.open(filename) as f: - if 'TIFF image data' in magic.from_buffer(f.read(2048)): - f.seek(0) - tiff = TiffFile(f) - # TODO: check when there are multiple series - axes = tiff.series[0].axes - array = reshape(tiff.asarray(), axes, 'ZYXC') - return array - else: - print(f'{filename} is not a tiff file.') - else: - print(f'{filename} not found in zip.') - print('Loading all tiffs in zip.') - tiffs = {} - for name in zf.namelist(): - with zf.open(name) as f: - if 'TIFF image data' in magic.from_buffer(f.read(2048)): - f.seek(0) - tiff = TiffFile(f).asarray() - tiffs[name] = tiff - if len(tiffs) > 0: - regex = r'(.*)batch_(\d*)_feature_(\d*)\.tif' - - def get_batch(filename): - match = re.match(regex, filename) - if match: - return int(match.group(2)) - - def get_feature(filename): - match = re.match(regex, filename) - if match: - return int(match.group(3)) - - filenames = list(tiffs.keys()) - all_have_batch = all(map(lambda x: x is not None, map(get_batch, filenames))) - if all_have_batch: # Use batches as Z dimension - batches = {} - for batch, batch_group in itertools.groupby(filenames, get_batch): - # Stack features on last axis - features = [ - tiffs[filename] - for filename in sorted(list(batch_group), key=get_feature) - ] - batches[batch] = np.stack(features, axis=-1) - # Stack batches on first axis - batches = map(lambda x: x[1], sorted(batches.items())) - array = np.stack(list(batches), axis=0) - return array - else: # Use each tiff as a channel and stack on the last axis - y = np.stack(list(tiffs.values()), axis=-1) - # Add Z axis - if y.ndim == 3: - y = y[np.newaxis, ...] - return y - - -def load_zip_png(zf): - """ - Returns the image data array for the first PNG image in the zip file - - Args: - zf: a ZipFile with a PNG - - Returns: - numpy array or None if no png in zip - """ - for name in zf.namelist(): - with zf.open(name) as f: - if 'PNG image data' in magic.from_buffer(f.read(2048)): - f.seek(0) - png = Image.open(f) - return np.array(png) - - -def load_zip_csv(zf): - """ - Returns the binary data for the first CSV file in the zip file, if it exists. 
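The OME-XML walk in `load_channels` above is worth seeing in isolation: iterating the whole tree and matching on `tag.endswith('Channel')` sidesteps the OME namespace prefix on tag names. A self-contained sketch, with inline XML standing in for `tiff.ome_metadata`:

```python
from xml.etree import ElementTree as ET

ome_xml = """<OME xmlns="http://www.openmicroscopy.org/Schemas/OME/2016-06">
  <Image><Pixels>
    <Channel Name="DAPI"/><Channel Name="GFP"/>
  </Pixels></Image>
</OME>"""

root = ET.fromstring(ome_xml)
# Namespaced tags look like '{http://...}Channel', so match on the suffix
channels = [
    child.attrib['Name']
    for child in root.iter()
    if child.tag.endswith('Channel') and 'Name' in child.attrib
]
print(channels)  # ['DAPI', 'GFP']
```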
- - Args: - zf: a ZipFile with a CSV - - Returns: - bytes or None if not a csv file - """ - for name in zf.namelist(): - if name.endswith('.csv'): - with zf.open(name) as f: - return f.read() - - -def load_zip_json(zf, filename=None): - """ - Returns a dicstion json file in the zip file, if it exists. - - Args: - zf: a ZipFile with a CSV - - Returns: - bytes or None if not a csv file - """ - if filename in zf.namelist(): - with zf.open(filename) as f: - try: - f.seek(0) - return json.load(f) - except json.JSONDecodeError as e: - print(f'Warning: Could not load {filename} as JSON. {e.msg}') - return - print(f'Warning: JSON file {filename} not found.') - - -def load_zip(f): - """ - Loads image data from a zip file by loading from the npz, tiff, or png files in the archive - - Args: - f: file object - - Returns: - numpy array or None if not a zip file - """ - f.seek(0) - if zipfile.is_zipfile(f): - zf = zipfile.ZipFile(f, 'r') - X = load_zip_numpy(zf) - if X is None: - X = load_zip_tiffs(zf, filename='X.ome.tiff') - if X is None: - X = load_zip_png(zf) - return X - - -def load_npy(f): - """ - Loads image data from a npy file - - Args: - f: file object - - Returns: - numpy array or None if not a npy file - """ - f.seek(0) - if 'NumPy data file' in magic.from_buffer(f.read(2048)): - f.seek(0) - npy = np.load(f) - return npy - - -def load_tiff(f, axes=None): - """ - Loads image data from a tiff file - - Args: - f: file object - - Returns: - numpy array or None if not a tiff file - - Raises: - ValueError: tiff has less than 2 or more than 4 dimensions - """ - f.seek(0) - if 'TIFF image data' in magic.from_buffer(f.read(2048)): - f.seek(0) - tiff = TiffFile(io.BytesIO(f.read())) - # Load array - if tiff.is_imagej: - X = tiff.asarray() - # TODO: use axes to know which axes to add and permute - # TODO: handle tiffs with multiple series - axes = tiff.series[0].axes - if len(axes) != len(X.shape): - print( - f'Warning: TIFF has shape {X.shape} and axes {axes} in ImageJ metadata' - ) - elif tiff.is_ome: - # TODO: use DimensionOrder from OME-TIFF metadata to know which axes to add and permute - X = tiff.asarray(squeeze=False) - else: - X = tiff.asarray(squeeze=False) - # Standardize dimensions - if X.ndim == 0: - raise ValueError('Loaded image has no data') - elif X.ndim == 1: - raise ValueError('Loaded tiff is 1 dimensional') - elif X.ndim == 2: - # Add Z and C axes - return X[np.newaxis, ..., np.newaxis] - elif X.ndim == 3: - if axes[0] == 'B': - return X[..., np.newaxis] - elif axes[-1] == 'B': - X = np.moveaxis(X, -1, 0) - return X[..., np.newaxis] - elif axes[0] == 'C': - X = np.moveaxis(X, 0, -1) - return X[np.newaxis, ...] - elif axes[-1] == 'C': - return X[np.newaxis, ...] - else: # Add channel axis by default - return X[..., np.newaxis] - elif X.ndim == 4: - if axes == 'BXYC': - return X - elif axes == 'CXYB': - X = np.moveaxis(X, (0, -1), (-1, 0)) - return X - else: - print( - f'Warning: tiff with shape {X.shape} has 4 dimensions, ' - f'but axes is {axes}. Assuming BXYC.' - ) - return X - else: - raise ValueError( - f'Loaded tiff with shape {X.shape} has more than 4 dimensions.' 
- ) - - -def load_png(f): - """ - Loads image data from a png file - - Args: - f: file object - - Returns: - numpy array or None if not a png file - """ - f.seek(0) - if 'PNG image data' in magic.from_buffer(f.read(2048)): - f.seek(0) - image = Image.open(f, formats=['PNG']) - # Add channel dimension at end to single channel images - if image.mode == 'L': # uint8 - X = np.array(image) - X = np.expand_dims(X, -1) - # TODO: support higher bit raw images - # Currently all images are converted to uint8 - elif image.mode == 'I' or image.mode == 'F': # int32 and float32 - # Rescale data - max, min = np.max(image), np.min(image) - X = (image - min) / (max - min if max - min > 0 else 1) * 255 - X = X.astype(np.uint8) - X = np.expand_dims(X, -1) - else: # P, RGB, RGBA, CMYK,YCbCr - # Create three RGB channels - # Handles RGB, RGBA, P modes - X = np.array(image.convert('RGB')) - # Add T axis at start - X = np.expand_dims(X, 0) - return X - - -def load_trk(f, filename='raw.npy'): - """ - Loads image data from a .trk file containing raw.npy, tracked.npy, and lineage.json - - Args: - f: file object containing a .trk file - filename: name of the file within the .trk to load - - Returns: - numpy array (for raw.npy or tracked.npy) or dictionary (for lineage.json) - """ - f.seek(0) - if tarfile.is_tarfile(f.name): - with tarfile.open(fileobj=f) as trks: - if filename == 'raw.npy' or filename == 'tracked.npy': - # numpy can't read these from disk... - with io.BytesIO() as array_file: - array_file.write(trks.extractfile(filename).read()) - array_file.seek(0) - return np.load(array_file) - if filename == 'lineage.json': - return json.loads( - trks.extractfile(trks.getmember('lineage.json')).read().decode() - ) +""" +Class to load data into a DeepCell Label project file +Loads both raw image data and labels +""" + +import io +import itertools +import json +import re +import tarfile +import tempfile +import zipfile +from xml.etree import ElementTree as ET + +import magic +import numpy as np + +# from flask import current_app +from PIL import Image +from tifffile import TiffFile, TiffWriter + +# from deepcell_label.client import send_to_server +from deepcell_label.utils import convert_lineage, reshape + + +class Loader: + """ + Loads and writes data into a DeepCell Label project zip. 
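The Loader constructor's round trip (build the project zip inside a temporary file, then read it back into memory) reduces to a few lines. A minimal sketch; the `writestr` calls here are placeholders for the `write_*` methods:

```python
import tempfile
import zipfile

with tempfile.TemporaryFile() as project_file:
    with zipfile.ZipFile(project_file, 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.writestr('cells.json', '[]')      # stand-in for write_cells()
        zf.writestr('divisions.json', '[]')  # stand-in for write_divisions()
    project_file.seek(0)
    data = project_file.read()  # bytes held in memory, later uploaded to S3
print(len(data) > 0)  # True
```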
+ """ + + def __init__(self, image_file=None, label_file=None, axes=None): + """ + Args: + image_file: file zip object containing a png, zip, tiff, or npz file + label_file: file like object containing a zip + axes: dimension order of the image data + """ + self.X = None + self.y = None + self.spots = None + self.divisions = None + self.cellTypes = None + self.cells = None + self.channels = [] + self.embeddings = None + + self.image_file = image_file + self.label_file = label_file if label_file else image_file + self.axes = axes + + with tempfile.TemporaryFile() as project_file: + with zipfile.ZipFile(project_file, 'w', zipfile.ZIP_DEFLATED) as zip: + self.zip = zip + self.load() + self.write() + project_file.seek(0) + self.data = project_file.read() + + def load(self): + """Loads data from input files.""" + self.X = load_images(self.image_file, self.axes) + self.y = load_segmentation(self.label_file) + self.spots = load_spots(self.label_file) + self.divisions = load_divisions(self.label_file) + self.cellTypes = load_cellTypes(self.label_file) + self.cells = load_cells(self.label_file) + self.channels = load_channels(self.image_file) + self.embeddings = load_embeddings(self.label_file) + + if self.y is None: + shape = (*self.X.shape[:-1], 1) + self.y = np.zeros(shape) + + def write(self): + """Writes loaded data to zip.""" + self.write_images() + self.write_segmentation() + self.write_spots() + self.write_divisions() + self.write_cellTypes() + self.write_cells() + self.write_embeddings() + + def write_images(self): + """ + Writes raw images to X.ome.tiff in the output zip. + + Raises: + ValueError: no image data has been loaded to write + """ + X = self.X + # send_to_server(X, "X") + if X is not None: + if len(X.shape) == 3: + X = np.expand_dims(X, 0) + self.X = X + # Move channel axis + X = np.moveaxis(X, -1, 1) + images = io.BytesIO() + channels = [] + for i in range(len(self.channels)): + channels.append({'Name': self.channels[i]}) + with TiffWriter(images, ome=True) as tif: + tif.write( + X, + compression='zlib', + metadata={'axes': 'ZCYX', 'Pixels': {'Channel': channels}}, + ) + images.seek(0) + self.zip.writestr('X.ome.tiff', images.read()) + # raise ValueError('No images found in files') + + def write_segmentation(self): + """Writes segmentation to y.ome.tiff in the output zip.""" + y = self.y + # send_to_server(y, "y") + if len(y.shape) == 2: + y = np.expand_dims(np.expand_dims(y, 0), 3) + self.y = y + if y.shape[:-1] != self.X.shape[:-1]: + raise ValueError( + 'Segmentation shape %s is incompatible with image shape %s' + % (y.shape, self.X.shape) + ) + # TODO: check if float vs int matters + y = y.astype(np.int32) + # Move channel axis + y = np.moveaxis(y, -1, 1) + + segmentation = io.BytesIO() + with TiffWriter(segmentation, ome=True) as tif: + tif.write(y, compression='zlib', metadata={'axes': 'ZCYX'}) + segmentation.seek(0) + self.zip.writestr('y.ome.tiff', segmentation.read()) + + def write_spots(self): + """Writes spots to spots.csv in the output zip.""" + if self.spots is not None: + buffer = io.BytesIO() + buffer.write(self.spots) + buffer.seek(0) + self.zip.writestr('spots.csv', buffer.read()) + + def write_divisions(self): + """Writes divisions to divisions.json in the output zip.""" + self.zip.writestr('divisions.json', json.dumps(self.divisions)) + + def write_cellTypes(self): + """Writes cell types to cellTypes.json in the output zip.""" + self.zip.writestr('cellTypes.json', json.dumps(self.cellTypes)) + + def write_embeddings(self): + """Writes embeddings to 
embeddings.json in the output zip.""" + self.zip.writestr('embeddings.json', json.dumps(self.embeddings)) + + def write_cells(self): + """Writes cells to cells.json in the output zip.""" + if self.cells is None: + cells = [] + # Instant segmentation upon loading is currently disabled: + # for t in range(self.y.shape[0]): + # for c in range(self.y.shape[-1]): + # for value in np.unique(self.y[t, :, :, c]): + # if value != 0: + # cells.append( + # { + # 'cell': int(value), + # 'value': int(value), + # 't': int(t), + # 'c': int(c), + # } + # ) + self.cells = cells + self.zip.writestr('cells.json', json.dumps(self.cells)) + + +def load_images(image_file, axes=None): + """ + Loads image data from image file. + + Args: + image_file: zip, npy, tiff, or png file object containing image data + + Returns: + numpy array or None if no image data found + """ + X = load_zip(image_file) + if X is None: + X = load_npy(image_file) + if X is None: + X = load_tiff(image_file, axes) + if X is None: + X = load_png(image_file) + if X is None: + X = load_trk(image_file, filename='raw.npy') + return X + + +def load_segmentation(f): + """ + Loads segmentation array from label file. + + Args: + f: file with zipped npy or tiff containing segmentation data + + Returns: + numpy array or None if no segmentation data found + """ + f.seek(0) + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + y = load_zip_numpy(zf, name='y') + if y is None: + y = load_zip_tiffs(zf, filename='y.ome.tiff') + return y + if tarfile.is_tarfile(f.name): + return load_trk(f, filename='tracked.npy') + + +def load_spots(f): + """ + Load spots data from label file. + + Args: + f: file with zipped csv containing spots data + + Returns: + bytes read from csv in zip or None if no csv in zip + """ + f.seek(0) + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + return load_zip_csv(zf) + + +def load_divisions(f): + """ + Load divisions from divisions.json in project archive + + Loading from lineage.json from .trk file is supported, but deprecated. + + Args: + f: zip file with divisions.json + or tarfile with lineage.json + + Returns: + list of divisions ([] if divisions.json not found) + """ + f.seek(0) + divisions = None + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + divisions = load_zip_json(zf, filename='divisions.json') + lineage = load_zip_json(zf, filename='lineage.json') + if lineage: + divisions = convert_lineage(lineage) + elif tarfile.is_tarfile(f.name): + lineage = load_trk(f, filename='lineage.json') + divisions = convert_lineage(lineage) + if divisions is None: + return [] + return divisions + + +def load_cellTypes(f): + """ + Load cell types from cellTypes.json in project archive + + Args: + f: zip file with cellTypes.json + + Returns: + list of cell types ([] if cellTypes.json not found) + """ + f.seek(0) + cellTypes = None + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + cellTypes = load_zip_json(zf, filename='cellTypes.json') + if cellTypes is None: + return [] + return cellTypes + + +def load_embeddings(f): + """ + Load embeddings from embeddings.json in project archive + + Args: + f: zip file with embeddings.json + + Returns: + dict or None if embeddings.json not found + """ + f.seek(0) + embeddings = None + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + embeddings = load_zip_json(zf, filename='embeddings.json') + return embeddings + + +def load_cells(f): + """ + Load cells from label file.
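`load_images` dispatches by sniffing bytes rather than trusting extensions. A sketch of that pattern with python-magic; the returned format tags are illustrative, not the module's API:

```python
import magic

def sniff_format(f):
    f.seek(0)
    description = magic.from_buffer(f.read(2048))  # e.g. 'TIFF image data, ...'
    f.seek(0)
    for marker, fmt in [
        ('Zip archive data', 'zip'),
        ('NumPy data file', 'npy'),
        ('TIFF image data', 'tiff'),
        ('PNG image data', 'png'),
    ]:
        if marker in description:
            return fmt
    return 'unknown'
```

Each `load_*` helper above applies the same test inline and returns None on a mismatch, so the next loader in the chain can try.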
+ + Args: + f: zip file with cells json + + Returns: + list of cells or None if cells.json not found in zip + """ + f.seek(0) + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + return load_zip_json(zf, filename='cells.json') + + +def load_channels(f): + """ + Load channels from raw file. + + Args: + f: X.ome.tiff or zip with X.ome.tiff with channel metadata + + Returns: + list of channel names ([] if no channel metadata found) + """ + f.seek(0) + channels = [] + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + for filename in zf.namelist(): + if filename == 'X.ome.tiff': + with zf.open(filename) as X: + tiff = TiffFile(X) + if tiff.is_ome: + root = ET.fromstring(tiff.ome_metadata) + for child in root.iter(): + if child.tag.endswith('Channel') and 'Name' in child.attrib: + channels.append(child.attrib['Name']) + return channels + + +def load_zip_numpy(zf, name='X'): + """ + Loads a numpy array from the zip file + If loading an NPZ with multiple arrays, name selects which one to load + + Args: + zf: a ZipFile with a npy or npz file + name (str): name of the array to load + + Returns: + numpy array or None if no npy or npz in zip + """ + for filename in zf.namelist(): + if filename == f'{name}.npy': + with zf.open(filename) as f: + return np.load(f) + if filename.endswith('.npz'): + with zf.open(filename) as f: + npz = np.load(f) + return npz[name] if name in npz.files else npz[npz.files[0]] + + +def load_zip_tiffs(zf, filename): + """ + Returns an array with all tiff image data in the zip file + + Args: + zf: a ZipFile containing tiffs to load + filename (str): name of a tiff to load directly, if present in the zip + + Returns: + numpy array or None if no tiffs in zip + """ + if filename in zf.namelist(): + with zf.open(filename) as f: + if 'TIFF image data' in magic.from_buffer(f.read(2048)): + f.seek(0) + tiff = TiffFile(f) + # TODO: check when there are multiple series + axes = tiff.series[0].axes + array = reshape(tiff.asarray(), axes, 'ZYXC') + return array + else: + print(f'{filename} is not a tiff file.') + else: + print(f'{filename} not found in zip.') + print('Loading all tiffs in zip.') + tiffs = {} + for name in zf.namelist(): + with zf.open(name) as f: + if 'TIFF image data' in magic.from_buffer(f.read(2048)): + f.seek(0) + tiff = TiffFile(f).asarray() + tiffs[name] = tiff + if len(tiffs) > 0: + regex = r'(.*)batch_(\d*)_feature_(\d*)\.tif' + + def get_batch(filename): + match = re.match(regex, filename) + if match: + return int(match.group(2)) + + def get_feature(filename): + match = re.match(regex, filename) + if match: + return int(match.group(3)) + + filenames = list(tiffs.keys()) + all_have_batch = all(map(lambda x: x is not None, map(get_batch, filenames))) + if all_have_batch: # Use batches as Z dimension + batches = {} + for batch, batch_group in itertools.groupby(filenames, get_batch): + # Stack features on last axis + features = [ + tiffs[filename] + for filename in sorted(list(batch_group), key=get_feature) + ] + batches[batch] = np.stack(features, axis=-1) + # Stack batches on first axis + batches = map(lambda x: x[1], sorted(batches.items())) + array = np.stack(list(batches), axis=0) + return array + else: # Use each tiff as a channel and stack on the last axis + y = np.stack(list(tiffs.values()), axis=-1) + # Add Z axis + if y.ndim == 3: + y = y[np.newaxis, ...]
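The batch/feature stacking in `load_zip_tiffs` is easier to follow with synthetic data: filenames like `prefix_batch_0_feature_1.tif` are grouped by batch (stacked on the Z axis) and ordered by feature (stacked on the channel axis). A sketch, sorting explicitly so the grouping is stable:

```python
import re
import numpy as np

regex = r'(.*)batch_(\d*)_feature_(\d*)\.tif'
# Synthetic stand-ins for the decoded tiffs
tiffs = {
    f'img_batch_{b}_feature_{f}.tif': np.full((8, 8), 10 * b + f)
    for b in range(2) for f in range(3)
}

def batch_of(name):
    return int(re.match(regex, name).group(2))

def feature_of(name):
    return int(re.match(regex, name).group(3))

batches = []
for b in sorted({batch_of(n) for n in tiffs}):
    names = sorted((n for n in tiffs if batch_of(n) == b), key=feature_of)
    batches.append(np.stack([tiffs[n] for n in names], axis=-1))  # C axis
array = np.stack(batches, axis=0)  # Z axis
print(array.shape)  # (2, 8, 8, 3), i.e. (Z, Y, X, C)
```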
+ return y + + +def load_zip_png(zf): + """ + Returns the image data array for the first PNG image in the zip file + + Args: + zf: a ZipFile with a PNG + + Returns: + numpy array or None if no png in zip + """ + for name in zf.namelist(): + with zf.open(name) as f: + if 'PNG image data' in magic.from_buffer(f.read(2048)): + f.seek(0) + png = Image.open(f) + return np.array(png) + + +def load_zip_csv(zf): + """ + Returns the binary data for the first CSV file in the zip file, if it exists. + + Args: + zf: a ZipFile with a CSV + + Returns: + bytes or None if not a csv file + """ + for name in zf.namelist(): + if name.endswith('.csv'): + with zf.open(name) as f: + return f.read() + + +def load_zip_json(zf, filename=None): + """ + Returns the parsed JSON from the named file in the zip file, if it exists. + + Args: + zf: a ZipFile containing the JSON file + filename (str): name of the JSON file to load + + Returns: + dict or list parsed from the JSON file, or None if the file is missing or malformed + """ + if filename in zf.namelist(): + with zf.open(filename) as f: + try: + f.seek(0) + return json.load(f) + except json.JSONDecodeError as e: + print(f'Warning: Could not load {filename} as JSON. {e.msg}') + return + print(f'Warning: JSON file {filename} not found.') + + +def load_zip(f): + """ + Loads image data from a zip file by loading from the npz, tiff, or png files in the archive + + Args: + f: file object + + Returns: + numpy array or None if not a zip file + """ + f.seek(0) + if zipfile.is_zipfile(f): + zf = zipfile.ZipFile(f, 'r') + X = load_zip_numpy(zf) + if X is None: + X = load_zip_tiffs(zf, filename='X.ome.tiff') + if X is None: + X = load_zip_png(zf) + return X + + +def load_npy(f): + """ + Loads image data from a npy file + + Args: + f: file object + + Returns: + numpy array or None if not a npy file + """ + f.seek(0) + if 'NumPy data file' in magic.from_buffer(f.read(2048)): + f.seek(0) + npy = np.load(f) + return npy + + +def load_tiff(f, axes=None): + """ + Loads image data from a tiff file + + Args: + f: file object + axes (str): optional dimension order to assume for the image data + + Returns: + numpy array or None if not a tiff file + + Raises: + ValueError: tiff has less than 2 or more than 4 dimensions + """ + f.seek(0) + if 'TIFF image data' in magic.from_buffer(f.read(2048)): + f.seek(0) + tiff = TiffFile(io.BytesIO(f.read())) + # Load array + if tiff.is_imagej: + X = tiff.asarray() + # TODO: use axes to know which axes to add and permute + # TODO: handle tiffs with multiple series + axes = tiff.series[0].axes + if len(axes) != len(X.shape): + print( + f'Warning: TIFF has shape {X.shape} and axes {axes} in ImageJ metadata' + ) + elif tiff.is_ome: + # TODO: use DimensionOrder from OME-TIFF metadata to know which axes to add and permute + X = tiff.asarray(squeeze=False) + else: + X = tiff.asarray(squeeze=False) + # Standardize dimensions + if X.ndim == 0: + raise ValueError('Loaded image has no data') + elif X.ndim == 1: + raise ValueError('Loaded tiff is 1 dimensional') + elif X.ndim == 2: + # Add Z and C axes + return X[np.newaxis, ..., np.newaxis] + elif X.ndim == 3: + if axes[0] == 'B': + return X[..., np.newaxis] + elif axes[-1] == 'B': + X = np.moveaxis(X, -1, 0) + return X[..., np.newaxis] + elif axes[0] == 'C': + X = np.moveaxis(X, 0, -1) + return X[np.newaxis, ...] + elif axes[-1] == 'C': + return X[np.newaxis, ...] + else: # Add channel axis by default + return X[..., np.newaxis] + elif X.ndim == 4: + if axes == 'BXYC': + return X + elif axes == 'CXYB': + X = np.moveaxis(X, (0, -1), (-1, 0)) + return X + else: + print( + f'Warning: tiff with shape {X.shape} has 4 dimensions, ' + f'but axes is {axes}. Assuming BXYC.'
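The defensive read in `load_zip_json` (missing file and malformed JSON both degrade to None with a warning) is a small pattern worth isolating. A minimal sketch with the same behavior:

```python
import json
import zipfile

def read_zip_json(zf: zipfile.ZipFile, filename: str):
    if filename not in zf.namelist():
        print(f'Warning: JSON file {filename} not found.')
        return None
    with zf.open(filename) as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            print(f'Warning: could not load {filename} as JSON. {e.msg}')
            return None
```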
+ ) + return X + else: + raise ValueError( + f'Loaded tiff with shape {X.shape} has more than 4 dimensions.' + ) + + +def load_png(f): + """ + Loads image data from a png file + + Args: + f: file object + + Returns: + numpy array or None if not a png file + """ + f.seek(0) + if 'PNG image data' in magic.from_buffer(f.read(2048)): + f.seek(0) + image = Image.open(f, formats=['PNG']) + # Add channel dimension at end to single channel images + if image.mode == 'L': # uint8 + X = np.array(image) + X = np.expand_dims(X, -1) + # TODO: support higher bit raw images + # Currently all images are converted to uint8 + elif image.mode == 'I' or image.mode == 'F': # int32 and float32 + # Rescale data + max, min = np.max(image), np.min(image) + X = (image - min) / (max - min if max - min > 0 else 1) * 255 + X = X.astype(np.uint8) + X = np.expand_dims(X, -1) + else: # P, RGB, RGBA, CMYK,YCbCr + # Create three RGB channels + # Handles RGB, RGBA, P modes + X = np.array(image.convert('RGB')) + # Add T axis at start + X = np.expand_dims(X, 0) + return X + + +def load_trk(f, filename='raw.npy'): + """ + Loads image data from a .trk file containing raw.npy, tracked.npy, and lineage.json + + Args: + f: file object containing a .trk file + filename: name of the file within the .trk to load + + Returns: + numpy array (for raw.npy or tracked.npy) or dictionary (for lineage.json) + """ + f.seek(0) + if tarfile.is_tarfile(f.name): + with tarfile.open(fileobj=f) as trks: + if filename == 'raw.npy' or filename == 'tracked.npy': + # numpy can't read these from disk... + with io.BytesIO() as array_file: + array_file.write(trks.extractfile(filename).read()) + array_file.seek(0) + return np.load(array_file) + if filename == 'lineage.json': + return json.loads( + trks.extractfile(trks.getmember('lineage.json')).read().decode() + ) diff --git a/backend/deepcell_label/models.py b/backend/deepcell_label/models.py index b5cc56784..fe854eb17 100644 --- a/backend/deepcell_label/models.py +++ b/backend/deepcell_label/models.py @@ -1,97 +1,97 @@ -"""SQL Alchemy database models.""" -from __future__ import absolute_import, division, print_function - -import io -import logging -import timeit -from secrets import token_urlsafe - -import boto3 -from flask_sqlalchemy import SQLAlchemy - -from deepcell_label.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_BUCKET - -logger = logging.getLogger('models.Project') # pylint: disable=C0103 -db = SQLAlchemy() # pylint: disable=C0103 - - -class Project(db.Model): - """Project table definition.""" - - # pylint: disable=E1101 - __tablename__ = 'projects' - id = db.Column(db.Integer, primary_key=True, autoincrement=True) - project = db.Column(db.String(12), unique=True, nullable=False, index=True) - createdAt = db.Column(db.TIMESTAMP, nullable=False, default=db.func.now()) - bucket = db.Column(db.Text, nullable=False) - key = db.Column(db.Text, nullable=False) - - def __init__(self, loader): - """ - Args: - loader: loaders.Loader object - """ - start = timeit.default_timer() - - # Create a unique 12 character base64 project ID - while True: - project = token_urlsafe(9) # 9 bytes is 12 base64 characters - if not db.session.query(Project).filter_by(project=project).first(): - self.project = project - break - - # Upload to s3 - s3 = boto3.client( - 's3', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - ) - self.bucket = S3_BUCKET - self.key = f'{self.project}.zip' - fileobj = io.BytesIO(loader.data) - s3.upload_fileobj(fileobj, self.bucket, self.key) - 
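`load_trk` above stages tar members through an in-memory buffer before handing them to `np.load`, which wants a clean, seekable file object. A hedged sketch of that handling; the archive layout (`raw.npy`, `tracked.npy`, `lineage.json`) follows the docstring, and the usage paths are illustrative:

```python
import io
import json
import tarfile
import numpy as np

def read_trk_member(path, member):
    with tarfile.open(path) as trk:
        data = trk.extractfile(member).read()
    if member.endswith('.json'):
        return json.loads(data.decode())
    # Route the bytes through BytesIO so np.load gets a seekable file object
    with io.BytesIO(data) as buffer:
        return np.load(buffer)

# Usage, assuming a .trk archive on disk:
# raw = read_trk_member('movie.trk', 'raw.npy')
# lineage = read_trk_member('movie.trk', 'lineage.json')
```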
- logger.debug( - 'Initialized project %s and uploaded to %s in %ss.', - self.project, - self.bucket, - timeit.default_timer() - start, - ) - - @staticmethod - def get(project): - """ - Return the project with the given ID, if it exists. - - Args: - project (int): unique 12 character base64 string to identify project - - Returns: - Project: row from the Project table - """ - start = timeit.default_timer() - project = db.session.query(Project).filter_by(project=project).first() - logger.debug('Got project %s in %ss.', project, timeit.default_timer() - start) - return project - - @staticmethod - def create(data): - """ - Create a new project in the Project table. - - Args: - data: zip file with loaded project data - - Returns: - Project: new row in the Project table - """ - start = timeit.default_timer() - project = Project(data) - db.session.add(project) - db.session.commit() - logger.debug( - 'Created new project %s in %ss.', - project.project, - timeit.default_timer() - start, - ) - return project +"""SQL Alchemy database models.""" +from __future__ import absolute_import, division, print_function + +import io +import logging +import timeit +from secrets import token_urlsafe + +import boto3 +from flask_sqlalchemy import SQLAlchemy + +from deepcell_label.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_BUCKET + +logger = logging.getLogger('models.Project') # pylint: disable=C0103 +db = SQLAlchemy() # pylint: disable=C0103 + + +class Project(db.Model): + """Project table definition.""" + + # pylint: disable=E1101 + __tablename__ = 'projects' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + project = db.Column(db.String(12), unique=True, nullable=False, index=True) + createdAt = db.Column(db.TIMESTAMP, nullable=False, default=db.func.now()) + bucket = db.Column(db.Text, nullable=False) + key = db.Column(db.Text, nullable=False) + + def __init__(self, loader): + """ + Args: + loader: loaders.Loader object + """ + start = timeit.default_timer() + + # Create a unique 12 character base64 project ID + while True: + project = token_urlsafe(9) # 9 bytes is 12 base64 characters + if not db.session.query(Project).filter_by(project=project).first(): + self.project = project + break + + # Upload to s3 + s3 = boto3.client( + 's3', + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + ) + self.bucket = S3_BUCKET + self.key = f'{self.project}.zip' + fileobj = io.BytesIO(loader.data) + s3.upload_fileobj(fileobj, self.bucket, self.key) + + logger.debug( + 'Initialized project %s and uploaded to %s in %ss.', + self.project, + self.bucket, + timeit.default_timer() - start, + ) + + @staticmethod + def get(project): + """ + Return the project with the given ID, if it exists. + + Args: + project (int): unique 12 character base64 string to identify project + + Returns: + Project: row from the Project table + """ + start = timeit.default_timer() + project = db.session.query(Project).filter_by(project=project).first() + logger.debug('Got project %s in %ss.', project, timeit.default_timer() - start) + return project + + @staticmethod + def create(data): + """ + Create a new project in the Project table. 
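The Project model's ID and upload scheme reduces to two small pieces: `token_urlsafe(9)` yields exactly 12 URL-safe base64 characters, retried until unused, and the zipped project bytes go to S3 via `upload_fileobj`. A sketch; the `exists` callback stands in for the database uniqueness query, and bucket names are illustrative:

```python
import io
from secrets import token_urlsafe

import boto3

def new_project_id(exists):
    # 9 random bytes encode to 12 URL-safe base64 characters
    while True:
        project = token_urlsafe(9)
        if not exists(project):  # stand-in for the Project table lookup
            return project

def upload_project(data, bucket, project):
    s3 = boto3.client('s3')
    s3.upload_fileobj(io.BytesIO(data), bucket, f'{project}.zip')
    return f'{project}.zip'
```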
+ + Args: + data: zip file with loaded project data + + Returns: + Project: new row in the Project table + """ + start = timeit.default_timer() + project = Project(data) + db.session.add(project) + db.session.commit() + logger.debug( + 'Created new project %s in %ss.', + project.project, + timeit.default_timer() - start, + ) + return project diff --git a/backend/requirements.txt b/backend/requirements.txt index 4a671ec73..9d4e93454 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -10,9 +10,10 @@ mysqlclient~=2.1.0 numpy python-decouple~=3.1 python-dotenv~=0.19.2 -python-magic~=0.4.25 +# python-magic~=0.4.25 requests~=2.29.0 scikit-image~=0.19.0 sqlalchemy~=1.3.24 tifffile imagecodecs +websockets diff --git a/cypress.config.js b/cypress.config.js new file mode 100644 index 000000000..0969aae3f --- /dev/null +++ b/cypress.config.js @@ -0,0 +1,7 @@ +module.exports = { + e2e: { + setupNodeEvents(on, config) { + // implement node event listeners here + }, + }, +}; diff --git a/cypress/fixtures/example.json b/cypress/fixtures/example.json new file mode 100644 index 000000000..02e425437 --- /dev/null +++ b/cypress/fixtures/example.json @@ -0,0 +1,5 @@ +{ + "name": "Using fixtures to represent data", + "email": "hello@cypress.io", + "body": "Fixtures are a great way to mock data for responses to routes" +} diff --git a/cypress/support/commands.js b/cypress/support/commands.js new file mode 100644 index 000000000..119ab03f7 --- /dev/null +++ b/cypress/support/commands.js @@ -0,0 +1,25 @@ +// *********************************************** +// This example commands.js shows you how to +// create various custom commands and overwrite +// existing commands. +// +// For more comprehensive examples of custom +// commands please read more here: +// https://on.cypress.io/custom-commands +// *********************************************** +// +// +// -- This is a parent command -- +// Cypress.Commands.add('login', (email, password) => { ... }) +// +// +// -- This is a child command -- +// Cypress.Commands.add('drag', { prevSubject: 'element'}, (subject, options) => { ... }) +// +// +// -- This is a dual command -- +// Cypress.Commands.add('dismiss', { prevSubject: 'optional'}, (subject, options) => { ... }) +// +// +// -- This will overwrite an existing command -- +// Cypress.Commands.overwrite('visit', (originalFn, url, options) => { ... }) diff --git a/cypress/support/e2e.js b/cypress/support/e2e.js new file mode 100644 index 000000000..5df9c0186 --- /dev/null +++ b/cypress/support/e2e.js @@ -0,0 +1,20 @@ +// *********************************************************** +// This example support/e2e.js is processed and +// loaded automatically before your test files. +// +// This is a great place to put global configuration and +// behavior that modifies Cypress. +// +// You can change the location of this file or turn off +// automatically serving support files with the +// 'supportFile' configuration option. 
+// +// You can read more here: +// https://on.cypress.io/configuration +// *********************************************************** + +// Import commands.js using ES2015 syntax: +import './commands'; + +// Alternatively you can use CommonJS syntax: +// require('./commands') diff --git a/frontend/cypress/screenshots/caliban.cy.js/shows loading spinner (failed).png b/frontend/cypress/screenshots/caliban.cy.js/shows loading spinner (failed).png new file mode 100644 index 000000000..d88144624 Binary files /dev/null and b/frontend/cypress/screenshots/caliban.cy.js/shows loading spinner (failed).png differ diff --git a/frontend/src/Project/DisplayControls/LabeledControls/CellsOpacitySlider.js b/frontend/src/Project/DisplayControls/LabeledControls/CellsOpacitySlider.js index bbb7cc248..68684ea88 100644 --- a/frontend/src/Project/DisplayControls/LabeledControls/CellsOpacitySlider.js +++ b/frontend/src/Project/DisplayControls/LabeledControls/CellsOpacitySlider.js @@ -3,7 +3,7 @@ import Slider from '@mui/material/Slider'; // import Tooltip from '@mui/material/Tooltip'; import { useSelector } from '@xstate/react'; import React, { useEffect } from 'react'; -import { useLabeled, useMousetrapRef } from '../../ProjectContext'; +import { useLabeled, useMousetrapRef, useLabelMode } from '../../ProjectContext'; let numMounted = 0; @@ -18,6 +18,9 @@ function CellsOpacitySlider() { const handleDoubleClick = () => labeled.send({ type: 'SET_CELLS_OPACITY', opacity: 0.3 }); // [0.3, 1] for range slider + const labelMode = useLabelMode(); + const isCellTypes = useSelector(labelMode, (state) => state.matches('editCellTypes')); + useEffect(() => { const listener = (e) => { if (e.key === 'z') { @@ -47,7 +50,7 @@ function CellsOpacitySlider() { return ( Z} placement='right'> ); diff --git a/frontend/src/Project/EditControls/EditControls.js b/frontend/src/Project/EditControls/EditControls.js index 5c582a78f..1a22c4399 100644 --- a/frontend/src/Project/EditControls/EditControls.js +++ b/frontend/src/Project/EditControls/EditControls.js @@ -6,6 +6,7 @@ import CellControls from './CellControls'; import CellTypeControls from './CellTypeControls'; import TrackingControls from './DivisionsControls'; import SegmentControls from './SegmentControls'; +import SegmentSamControls from './SegmentSamControls'; function TabPanel(props) { const { children, value, index, ...other } = props; @@ -28,14 +29,16 @@ function EditControls() { const value = useSelector(labelMode, (state) => { return state.matches('editSegment') ? 0 - : state.matches('editCells') + : state.matches('editSegmentSam') ? 1 - : state.matches('editDivisions') + : state.matches('editCells') ? 2 - : state.matches('editCellTypes') + : state.matches('editDivisions') ? 3 - : state.matches('editSpots') + : state.matches('editCellTypes') ? 4 + : state.matches('editSpots') + ? 5 : false; }); @@ -51,15 +54,18 @@ function EditControls() { - + - + - + + + + diff --git a/frontend/src/Project/EditControls/EditTabs.js b/frontend/src/Project/EditControls/EditTabs.js index fdabcae2a..4cf7f2dc6 100644 --- a/frontend/src/Project/EditControls/EditTabs.js +++ b/frontend/src/Project/EditControls/EditTabs.js @@ -13,14 +13,16 @@ function EditTabs() { const value = useSelector(labelMode, (state) => { return state.matches('editSegment') ? 0 - : state.matches('editCells') + : state.matches('editSegmentSam') ? 1 - : state.matches('editDivisions') + : state.matches('editCells') ? 2 - : state.matches('editCellTypes') + : state.matches('editDivisions') ? 
3 - : state.matches('editSpots') + : state.matches('editCellTypes') ? 4 + : state.matches('editSpots') + ? 5 : false; }); const handleChange = (event, newValue) => { @@ -29,15 +31,18 @@ labelMode.send('EDIT_SEGMENT'); break; case 1: - labelMode.send('EDIT_CELLS'); + labelMode.send('EDIT_SEGMENT_SAM'); break; case 2: - labelMode.send('EDIT_DIVISIONS'); + labelMode.send('EDIT_CELLS'); break; case 3: - labelMode.send('EDIT_CELLTYPES'); + labelMode.send('EDIT_DIVISIONS'); break; case 4: + labelMode.send('EDIT_CELLTYPES'); + break; + case 5: labelMode.send('EDIT_SPOTS'); break; default: @@ -72,6 +77,7 @@ variant='scrollable' > + + diff --git a/frontend/src/Project/EditControls/SegmentSamControls/ActionButtons.js b/frontend/src/Project/EditControls/SegmentSamControls/ActionButtons.js new file mode 100644 index 000000000..11057e318 --- /dev/null +++ b/frontend/src/Project/EditControls/SegmentSamControls/ActionButtons.js @@ -0,0 +1,17 @@ +import { FormLabel } from '@mui/material'; +import Box from '@mui/material/Box'; +import ButtonGroup from '@mui/material/ButtonGroup'; +import SegmentAllButton from './ActionButtons/SegmentAllButton'; + +function ActionButtons() { + return ( + <Box> + <FormLabel>Actions</FormLabel> + <ButtonGroup> + <SegmentAllButton /> + </ButtonGroup> + </Box> + ); +} + +export default ActionButtons; diff --git a/frontend/src/Project/EditControls/SegmentSamControls/ActionButtons/ActionButton.js b/frontend/src/Project/EditControls/SegmentSamControls/ActionButtons/ActionButton.js new file mode 100644 index 000000000..2c1e09c51 --- /dev/null +++ b/frontend/src/Project/EditControls/SegmentSamControls/ActionButtons/ActionButton.js @@ -0,0 +1,37 @@ +import Button from '@mui/material/Button'; +import Tooltip from '@mui/material/Tooltip'; +import { bind, unbind } from 'mousetrap'; +import React, { useEffect } from 'react'; + +// for adding tooltips to disabled buttons +// from https://stackoverflow.com/questions/61115913 + +const ActionButton = ({ tooltipText, disabled, onClick, hotkey, ...other }) => { + const adjustedButtonProps = { + disabled: disabled, + component: disabled ? 'div' : undefined, + onClick: disabled ? undefined : onClick, + }; + + useEffect(() => { + bind(hotkey, onClick); + return () => unbind(hotkey); // clean up the binding on unmount or when the hotkey changes + }, [hotkey, onClick]); + + return ( +