Skip to content

Commit

Permalink
ACLSD model
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 7, 2023
1 parent 57da521 commit f6a4b94
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions src/raygun/torch/models/ACLSDModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from raygun.torch.networks import UNet
from raygun.torch.networks.UNet import ConvPass
import torch
import logging

torch.backends.cudnn.benchmark = True
logging.basicConfig(level=logging.INFO)

# long range affs - use 20 output features
# increase number of downsampling layers for more features in the bottleneck
class ACLSDModel(torch.nn.Module):
def __init__(
self,
unet_kwargs={
"input_nc": 1,
"ngf": 12,
"fmap_inc_factor": 6,
"downsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)],
"constant_upsample": True,
},
num_affs=3,
):
super().__init__()

self.unet = UNet(**unet_kwargs)

self.aff_head = ConvPass(
unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid"
)

self.output_arrays = ["pred_affs"]
self.data_dict = {}

def add_log(self, writer, step):
# add loss input image examples
for name, data in self.data_dict.items():
if len(data.shape) > 3: # pull out batch dimension if necessary
img = data[0].squeeze()
else:
img = data.squeeze()

if len(img.shape) == 3:
mid = img.shape[0] // 2 # for 3D volume
img = img[mid]

if (
(img.min() < 0) and (img.min() >= -1.0) and (img.max() <= 1.0)
): # scale img to [0,1] if necessary
img = (img * 0.5) + 0.5
writer.add_image(name, img, global_step=step, dataformats="HW")

def forward(self, raw):
self.data_dict.update({"raw": raw.detach()})
z = self.unet(raw)
affs = self.aff_head(z)

return affs

0 comments on commit f6a4b94

Please sign in to comment.