From 1a361ccec9b60a1fc00079df419d5612cd2a857a Mon Sep 17 00:00:00 2001 From: arkkienkeli Date: Wed, 19 May 2021 12:37:40 +0200 Subject: [PATCH] https://github.com/cytomining/DeepProfiler/issues/258 --- deepprofiler/__main__.py | 8 ++++---- deepprofiler/dataset/image_dataset.py | 19 +++++++++++-------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/deepprofiler/__main__.py b/deepprofiler/__main__.py index f195955..70ef5b6 100644 --- a/deepprofiler/__main__.py +++ b/deepprofiler/__main__.py @@ -146,7 +146,7 @@ def prepare(context): def sample_sc(context): if context.parent.obj["config"]["prepare"]["compression"]["implement"]: context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"] - dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"]) + dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='train') deepprofiler.dataset.sampling.sample_dataset(context.obj["config"], dset) print("Single-cell sampling complete.") @@ -159,7 +159,7 @@ def sample_sc(context): def train(context, epoch, seed): if context.parent.obj["config"]["prepare"]["compression"]["implement"]: context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"] - dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"]) + dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='train') deepprofiler.learning.training.learn_model(context.obj["config"], dset, epoch, seed) @@ -177,8 +177,8 @@ def profile(context, part): if part >= 0: partfile = "index-{0:03d}.csv".format(part) config["paths"]["index"] = context.obj["config"]["paths"]["index"].replace("index.csv", partfile) - metadata = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"]) - deepprofiler.learning.profiling.profile(context.obj["config"], metadata) + dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='profile') + deepprofiler.learning.profiling.profile(context.obj["config"], dset) # Auxiliary tool: Split index in multiple parts diff --git a/deepprofiler/dataset/image_dataset.py b/deepprofiler/dataset/image_dataset.py index 3b6b8c1..5f99a85 100644 --- a/deepprofiler/dataset/image_dataset.py +++ b/deepprofiler/dataset/image_dataset.py @@ -195,7 +195,7 @@ def number_of_records(self, dataset): def add_target(self, new_target): self.targets.append(new_target) -def read_dataset(config): +def read_dataset(config, mode = 'train'): # Read metadata and split dataset in training and validation metadata = deepprofiler.dataset.metadata.Metadata(config["paths"]["index"], dtype=None) if config["prepare"]["compression"]["implement"]: @@ -211,10 +211,12 @@ def read_dataset(config): print(metadata.data.info()) # Split training data - split_field = config["train"]["partition"]["split_field"] - trainingFilter = lambda df: df[split_field].isin(config["train"]["partition"]["training_values"]) - validationFilter = lambda df: df[split_field].isin(config["train"]["partition"]["validation_values"]) - metadata.splitMetadata(trainingFilter, validationFilter) + if mode == 'train': + split_field = config["train"]["partition"]["split_field"] + trainingFilter = lambda df: df[split_field].isin(config["train"]["partition"]["training_values"]) + validationFilter = lambda df: df[split_field].isin(config["train"]["partition"]["validation_values"]) + metadata.splitMetadata(trainingFilter, validationFilter) + # Create a dataset keyGen = lambda r: "{}/{}-{}".format(r["Metadata_Plate"], r["Metadata_Well"], r["Metadata_Site"]) @@ -228,9 +230,10 @@ def read_dataset(config): ) # Add training targets - for t in config["train"]["partition"]["targets"]: - new_target = deepprofiler.dataset.target.MetadataColumnTarget(t, metadata.data[t].unique()) - dset.add_target(new_target) + if mode == 'train': + for t in config["train"]["partition"]["targets"]: + new_target = deepprofiler.dataset.target.MetadataColumnTarget(t, metadata.data[t].unique()) + dset.add_target(new_target) # Activate outlines for masking if needed if config["dataset"]["locations"]["mask_objects"]: