Skip to content

Commit

Permalink
working warpfield
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut committed Jul 2, 2024
1 parent 1e0f4ca commit 2fb2e12
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
19 changes: 12 additions & 7 deletions data/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,18 +438,20 @@ def torch_transform(self, image: torch.Tensor) -> T.Transform:
total_sigma = self.sigma * min_length
pad = round(truncate * total_sigma)
kernel_size = 2 * pad + 1
random_x = torch.rand((height + 2 * pad, width + 2 * pad), device=image.device) * 2 - 1
random_y = torch.rand((height + 2 * pad, width + 2 * pad), device=image.device) * 2 - 1
random_x = torch.rand((1, height + 2 * pad, width + 2 * pad), device=image.device) * 2 - 1
random_y = torch.rand((1, height + 2 * pad, width + 2 * pad), device=image.device) * 2 - 1

dx = F.gaussian_blur(
random_x,
kernel_size=[kernel_size, kernel_size],
sigma=[total_sigma, total_sigma],
)[pad:-pad, pad:-pad]
)[..., pad:-pad, pad:-pad]
dy = F.gaussian_blur(
random_y,
kernel_size=[kernel_size, kernel_size],
sigma=[total_sigma, total_sigma],
)[pad:-pad, pad:-pad]
)[..., pad:-pad, pad:-pad]

warpfield[0] = dx * min_length * self.alpha
warpfield[1] = dy * min_length * self.alpha

Expand Down Expand Up @@ -1706,10 +1708,10 @@ def test(args) -> None:
# augs = build_augmentation(cfg, mode="train")
# aug = T.AugmentationList(augs)

augs = [RandomAffine()]
augs = [RandomElastic()]
aug = T.AugmentationList(augs)

input_image = image.clone()
input_image = image.copy() if isinstance(image, np.ndarray) else image.clone()
output = AugInput(image=input_image, sem_seg=sem_seg)
transforms = aug(output)
transforms = [t for t in transforms.transforms if not isinstance(t, T.NoOpTransform)]
Expand All @@ -1718,7 +1720,10 @@ def test(args) -> None:
print(image.shape)
print(image.dtype)
print(image.min(), image.max())
im = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())

if isinstance(image, torch.Tensor):
image = image.permute(1, 2, 0).cpu().numpy()
im = Image.fromarray(image)
im.show("Original")

if isinstance(output.image, torch.Tensor):
Expand Down
4 changes: 2 additions & 2 deletions data/torch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,12 @@ def apply_segmentation(self, segmentation: torch.Tensor) -> torch.Tensor:
torch.Tensor: warped segmentation
"""
segmentation = segmentation.to(dtype=torch.float32)
segmentation = torch.stack([segmentation, torch.ones_like(segmentation)], dim=0)
segmentation = torch.stack([segmentation, torch.ones_like(segmentation)], dim=1)

sampled_segmentation = torch.nn.functional.grid_sample(
segmentation, self.indices[None, ...], mode="nearest", padding_mode="zeros", align_corners=False
)
out_of_bounds = segmentation[:, 1] == 0
out_of_bounds = sampled_segmentation[:, 1] == 0
# Set out of bounds to ignore value (remove if you don't want to ignore)
sampled_segmentation[:, 0][out_of_bounds] = self.ignore_value

Expand Down

0 comments on commit 2fb2e12

Please sign in to comment.