|
19 | 19 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
20 | 20 | import pytorch_lightning as pl
|
21 | 21 | import torch
|
| 22 | +from types import SimpleNamespace |
22 | 23 |
|
23 | 24 | # Deal with the filesystem
|
24 | 25 | import torch.multiprocessing
|
@@ -105,7 +106,7 @@ def shape_embed_process():
|
105 | 106 | window_size = 128 * 2
|
106 | 107 |
|
107 | 108 | params = {
|
108 |
| - "model":"resnet50_vqvae", |
| 109 | + "model":"resnet18_vqvae_legacy", |
109 | 110 | "epochs": 75,
|
110 | 111 | "batch_size": 4,
|
111 | 112 | "num_workers": 2**4,
|
@@ -142,14 +143,15 @@ def shape_embed_process():
|
142 | 143 |
|
143 | 144 | args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)
|
144 | 145 |
|
145 |
| - dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm" |
| 146 | + #dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm" |
| 147 | + dataset_path = "shape_embed_data/data/bbbc010/BBBC010_v1_foreground_eachworm/" |
146 | 148 | # dataset_path = "vampire/mefs/data/processed/Control"
|
147 | 149 | # dataset_path = "shape_embed_data/data/vampire/torchvision/Control/"
|
148 | 150 | # dataset_path = "vampire/torchvision/Control"
|
149 | 151 | # dataset = "bbbc010"
|
150 | 152 |
|
151 | 153 | # train_data_path = f"scripts/shapes/data/{dataset_path}"
|
152 |
| - train_data_path = f"data/{dataset_path}" |
| 154 | + train_data_path = f"scripts/shapes/data/{dataset_path}" |
153 | 155 | metadata = lambda x: f"results/{dataset_path}_{args.model}/{x}"
|
154 | 156 |
|
155 | 157 | path = Path(metadata(""))
|
@@ -336,8 +338,10 @@ def shape_embed_process():
|
336 | 338 | dataloader.setup()
|
337 | 339 |
|
338 | 340 | predictions = trainer.predict(lit_model, datamodule=dataloader)
|
339 |
| - latent_space = torch.stack([d["z"].flatten() for d in predictions]) |
340 |
| - scalings = torch.stack([d["scalings"].flatten() for d in predictions]) |
| 341 | + |
| 342 | + # Use the namespace variables |
| 343 | + latent_space = torch.stack([d.out.z.flatten() for d in predictions]) |
| 344 | + scalings = torch.stack([d.x.scalings.flatten() for d in predictions]) |
341 | 345 |
|
342 | 346 | idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}
|
343 | 347 |
|
@@ -366,7 +370,7 @@ def shape_embed_process():
|
366 | 370 | # Map numeric classes to their labels
|
367 | 371 | idx_to_class = {0: "alive", 1: "dead"}
|
368 | 372 | df["Class"] = df["Class"].map(idx_to_class)
|
369 |
| - df["Scale"] = scalings |
| 373 | + df["Scale"] = scalings[:, 0].squeeze() |
370 | 374 | df = df.set_index("Class")
|
371 | 375 | df_shape_embed = df.copy()
|
372 | 376 |
|
|
0 commit comments