Skip to content

Commit

Permalink
Misc fudging about
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 17, 2024
1 parent 857355f commit e6964a1
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions spleen_segmentation_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
# In[7]:


set_determinism(seed=0)
set_determinism(seed=42)


# ## Setup transforms for training and validation
Expand Down Expand Up @@ -253,6 +253,21 @@
# In[10]:


def tensor_desc(T: torch.Tensor, **kwargs) -> torch.Tensor:
strAs: str = "meanstd"
v1: float = 0.0
v2: float = 0.0
tensor: torch.Tensor = torch.Tensor([v1, v2])
for k, v in kwargs.items():
if k.lower() == "as":
strAs = v
match strAs:
case "meanstd":
tensor = torch.Tensor([T.mean().item(), T.std().item()])
return tensor


torch.manual_seed(42)
train_ds = CacheDataset(
data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4
)
Expand All @@ -261,18 +276,29 @@
# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
# for _ in range(0, 1):
# print("")
# for sample in train_loader:
# input = sample["image"].to(torch.device("cuda:0"))
# print(tensor_desc(input))

val_ds = CacheDataset(
data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4
)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
# for sample in val_loader:
# input = sample["image"].to(torch.device("cuda:0"))
# print(tensor_desc(input))


# ## Create Model, Loss, Optimizer

# In[12]:

print("")
# for sample in train_loader:
# input = sample["image"].to(torch.device("cuda:0"))
# print(tensor_desc(input))

# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
Expand All @@ -294,6 +320,10 @@

# In[13]:

print("")
for sample in train_loader:
input = sample["image"].to(torch.device("cuda:0"))
print(tensor_desc(input))

max_epochs = 600
val_interval = 2
Expand All @@ -304,6 +334,7 @@
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])


for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
Expand All @@ -317,7 +348,9 @@
batch_data["label"].to(device),
)
optimizer.zero_grad()
print(tensor_desc(inputs))
outputs = model(inputs)
print(tensor_desc(outputs))
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
Expand Down

0 comments on commit e6964a1

Please sign in to comment.