Skip to content

Commit 0d5dc62

Browse files
committed
fix for shape embed: use namespace variables and fix the scaling dimensions
1 parent 581f98f commit 0d5dc62

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

scripts/shapes/shape_embed.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
2020
import pytorch_lightning as pl
2121
import torch
22+
from types import SimpleNamespace
2223

2324
# Deal with the filesystem
2425
import torch.multiprocessing
@@ -105,7 +106,7 @@ def shape_embed_process():
105106
window_size = 128 * 2
106107

107108
params = {
108-
"model":"resnet50_vqvae",
109+
"model":"resnet18_vqvae_legacy",
109110
"epochs": 75,
110111
"batch_size": 4,
111112
"num_workers": 2**4,
@@ -142,14 +143,15 @@ def shape_embed_process():
142143

143144
args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)
144145

145-
dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm"
146+
#dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm"
147+
dataset_path = "shape_embed_data/data/bbbc010/BBBC010_v1_foreground_eachworm/"
146148
# dataset_path = "vampire/mefs/data/processed/Control"
147149
# dataset_path = "shape_embed_data/data/vampire/torchvision/Control/"
148150
# dataset_path = "vampire/torchvision/Control"
149151
# dataset = "bbbc010"
150152

151153
# train_data_path = f"scripts/shapes/data/{dataset_path}"
152-
train_data_path = f"data/{dataset_path}"
154+
train_data_path = f"scripts/shapes/data/{dataset_path}"
153155
metadata = lambda x: f"results/{dataset_path}_{args.model}/{x}"
154156

155157
path = Path(metadata(""))
@@ -336,8 +338,10 @@ def shape_embed_process():
336338
dataloader.setup()
337339

338340
predictions = trainer.predict(lit_model, datamodule=dataloader)
339-
latent_space = torch.stack([d["z"].flatten() for d in predictions])
340-
scalings = torch.stack([d["scalings"].flatten() for d in predictions])
341+
342+
# Use the namespace variables
343+
latent_space = torch.stack([d.out.z.flatten() for d in predictions])
344+
scalings = torch.stack([d.x.scalings.flatten() for d in predictions])
341345

342346
idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}
343347

@@ -366,7 +370,7 @@ def shape_embed_process():
366370
# Map numeric classes to their labels
367371
idx_to_class = {0: "alive", 1: "dead"}
368372
df["Class"] = df["Class"].map(idx_to_class)
369-
df["Scale"] = scalings
373+
df["Scale"] = scalings[:, 0].squeeze()
370374
df = df.set_index("Class")
371375
df_shape_embed = df.copy()
372376

0 commit comments

Comments
 (0)