Skip to content

Commit

Permalink
Merge pull request #243 from lkeegan/gpu_volume_creation
Browse files Browse the repository at this point in the history
Create simulation volumes on the GPU
  • Loading branch information
kdreher authored Sep 21, 2023
2 parents 71fc5bd + 709804a commit ed02703
Show file tree
Hide file tree
Showing 14 changed files with 146 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from simpa.core.simulation_modules.reconstruction_module import ReconstructionAdapterBase
import numpy as np
import torch
from simpa.core.simulation_modules.reconstruction_module.reconstruction_utils import compute_delay_and_sum_values,\
from simpa.core.simulation_modules.reconstruction_module.reconstruction_utils import compute_delay_and_sum_values, \
compute_image_dimensions, preparing_reconstruction_and_obtaining_reconstruction_settings
from simpa.core.device_digital_twins import DetectionGeometryBase
from simpa.core.simulation_modules.reconstruction_module import create_reconstruction_settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from simpa.utils import Tags
from simpa.utils.tissue_properties import TissueProperties
import numpy as np
import torch
from simpa.core import SimulationModule
from simpa.utils.dict_path_manager import generate_dict_path
from simpa.io_handling import save_data_field
from simpa.utils.quality_assurance.data_sanity_testing import assert_equal_shapes, assert_array_well_defined
from simpa.utils.processing_device import get_processing_device


class VolumeCreatorModuleBase(SimulationModule):
Expand All @@ -22,6 +24,7 @@ class VolumeCreatorModuleBase(SimulationModule):
def __init__(self, global_settings: Settings):
super(VolumeCreatorModuleBase, self).__init__(global_settings=global_settings)
self.component_settings = global_settings.get_volume_creation_settings()
self.torch_device = get_processing_device(self.global_settings)

def create_empty_volumes(self):
volumes = dict()
Expand All @@ -38,7 +41,7 @@ def create_empty_volumes(self):
# Create wavelength-independent properties only in the first wavelength run
if key in TissueProperties.wavelength_independent_properties and wavelength != first_wavelength:
continue
volumes[key] = np.zeros(sizes)
volumes[key] = torch.zeros(sizes, dtype=torch.float, device=self.torch_device)

return volumes, volume_x_dim, volume_y_dim, volume_z_dim

Expand All @@ -57,6 +60,8 @@ def run(self, device):
self.logger.info("VOLUME CREATION")

volumes = self.create_simulation_volume()
# explicitly empty cache to free reserved GPU memory after volume creation
torch.cuda.empty_cache()

if not (Tags.IGNORE_QA_ASSERTIONS in self.global_settings and Tags.IGNORE_QA_ASSERTIONS):
assert_equal_shapes(list(volumes.values()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT

from simpa.core.simulation_modules.volume_creation_module import VolumeCreatorModuleBase
from simpa.utils.libraries.structure_library import Structures
from simpa.utils.libraries.structure_library import priority_sorted_structures
from simpa.utils import Tags
import numpy as np
from simpa.utils import create_deformation_settings
Expand Down Expand Up @@ -60,19 +60,18 @@ def create_simulation_volume(self) -> dict:
cosine_scaling_factor=1)

volumes, x_dim_px, y_dim_px, z_dim_px = self.create_empty_volumes()
global_volume_fractions = np.zeros((x_dim_px, y_dim_px, z_dim_px))
max_added_fractions = np.zeros((x_dim_px, y_dim_px, z_dim_px))
global_volume_fractions = torch.zeros((x_dim_px, y_dim_px, z_dim_px),
dtype=torch.float, device=self.torch_device)
max_added_fractions = torch.zeros((x_dim_px, y_dim_px, z_dim_px), dtype=torch.float, device=self.torch_device)
wavelength = self.global_settings[Tags.WAVELENGTH]

structure_list = Structures(self.global_settings, self.component_settings)
priority_sorted_structures = structure_list.sorted_structures

for structure in priority_sorted_structures:
for structure in priority_sorted_structures(self.global_settings, self.component_settings):
self.logger.debug(type(structure))

structure_properties = structure.properties_for_wavelength(wavelength)

structure_volume_fractions = structure.geometrical_volume
structure_volume_fractions = torch.as_tensor(
structure.geometrical_volume, dtype=torch.float, device=self.torch_device)
structure_indexes_mask = structure_volume_fractions > 0
global_volume_fractions_mask = global_volume_fractions < 1
mask = structure_indexes_mask & global_volume_fractions_mask
Expand All @@ -82,11 +81,11 @@ def create_simulation_volume(self) -> dict:
added_volume_fraction <= 1 & mask]

selector_more_than_1 = added_volume_fraction > 1
if selector_more_than_1.any():
if torch.any(selector_more_than_1):
remaining_volume_fraction_to_fill = 1 - global_volume_fractions[selector_more_than_1]
fraction_to_be_filled = structure_volume_fractions[selector_more_than_1]
added_volume_fraction[selector_more_than_1] = np.min([remaining_volume_fraction_to_fill,
fraction_to_be_filled], axis=0)
added_volume_fraction[selector_more_than_1] = torch.min(torch.stack((remaining_volume_fraction_to_fill,
fraction_to_be_filled)), 0).values
for key in volumes.keys():
if structure_properties[key] is None:
continue
Expand All @@ -100,7 +99,8 @@ def create_simulation_volume(self) -> dict:

global_volume_fractions[mask] += added_volume_fraction[mask]

# explicitly empty cache to free reserved GPU memory after volume creation
torch.cuda.empty_cache()
# convert volumes back to CPU
for key in volumes.keys():
volumes[key] = volumes[key].cpu().numpy().astype(np.float64, copy=False)

return volumes
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Background(GeometricalStructure):
def get_enclosed_indices(self):
array = np.ones((self.volume_dimensions_voxels[0],
self.volume_dimensions_voxels[1],
self.volume_dimensions_voxels[2]))
self.volume_dimensions_voxels[2]), dtype=np.float32)
return array == 1, 1

def get_params_from_settings(self, single_structure_settings):
Expand Down
34 changes: 18 additions & 16 deletions simpa/utils/libraries/structure_library/CircularTubularStructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,26 @@ def to_settings(self):

def get_enclosed_indices(self):
start_mm, end_mm, radius_mm, partial_volume = self.params
start_mm = torch.tensor(start_mm, dtype=torch.float).to(self.torch_device)
end_mm = torch.tensor(end_mm, dtype=torch.float).to(self.torch_device)
radius_mm = torch.tensor(radius_mm, dtype=torch.float).to(self.torch_device)
start_mm = torch.tensor(start_mm, dtype=torch.float, device=self.torch_device)
end_mm = torch.tensor(end_mm, dtype=torch.float, device=self.torch_device)
radius_mm = torch.tensor(radius_mm, dtype=torch.float, device=self.torch_device)
start_voxels = start_mm / self.voxel_spacing
end_voxels = end_mm / self.voxel_spacing
radius_voxels = radius_mm / self.voxel_spacing

x, y, z = torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0]).to(self.torch_device),
torch.arange(self.volume_dimensions_voxels[1]).to(self.torch_device),
torch.arange(self.volume_dimensions_voxels[2]).to(self.torch_device),
indexing='ij')

x = x + 0.5
y = y + 0.5
z = z + 0.5
target_vector = torch.stack(torch.meshgrid(torch.arange(start=0.5, end=self.volume_dimensions_voxels[0], dtype=torch.float, device=self.torch_device),
torch.arange(
start=0.5, end=self.volume_dimensions_voxels[1], dtype=torch.float, device=self.torch_device),
torch.arange(
start=0.5, end=self.volume_dimensions_voxels[2], dtype=torch.float, device=self.torch_device),
indexing='ij'), dim=-1)
target_vector -= start_voxels

if partial_volume:
radius_margin = 0.5
else:
radius_margin = 0.7071

target_vector = torch.subtract(torch.stack([x, y, z], axis=-1), start_voxels)
if self.do_deformation:
# the deformation functional needs mm as inputs and returns the result in reverse indexing order...
deformation_values_mm = self.deformation_functional_mm(torch.arange(self.volume_dimensions_voxels[0]) *
Expand All @@ -78,16 +76,20 @@ def get_enclosed_indices(self):
self.voxel_spacing).T
deformation_values_mm = deformation_values_mm.reshape(self.volume_dimensions_voxels[0],
self.volume_dimensions_voxels[1], 1, 1)
deformation_values_mm = torch.tile(torch.from_numpy(deformation_values_mm).to(
self.torch_device), (1, 1, self.volume_dimensions_voxels[2], 3))
target_vector = (target_vector + (deformation_values_mm / self.voxel_spacing)).float()
deformation_values_mm = torch.tile(torch.as_tensor(
deformation_values_mm, dtype=torch.float, device=self.torch_device), (1, 1, self.volume_dimensions_voxels[2], 3))
deformation_values_mm /= self.voxel_spacing
target_vector += deformation_values_mm
del deformation_values_mm
cylinder_vector = torch.subtract(end_voxels, start_voxels)

target_radius = torch.linalg.norm(target_vector, axis=-1) * torch.sin(
torch.arccos((torch.matmul(target_vector, cylinder_vector)) /
(torch.linalg.norm(target_vector, axis=-1) * torch.linalg.norm(cylinder_vector))))
del target_vector

volume_fractions = torch.zeros(tuple(self.volume_dimensions_voxels), dtype=torch.float).to(self.torch_device)
volume_fractions = torch.zeros(tuple(self.volume_dimensions_voxels),
dtype=torch.float, device=self.torch_device)

filled_mask = target_radius <= radius_voxels - 1 + radius_margin
border_mask = (target_radius > radius_voxels - 1 + radius_margin) & \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,27 @@ def get_enclosed_indices(self):
if direction_mm[0] != 0 or direction_mm[1] != 0 or direction_mm[2] == 0:
raise ValueError("Horizontal Layer structure needs a start and end vector in the form of [0, 0, n].")

x, y, z = torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0]).to(self.torch_device),
torch.arange(self.volume_dimensions_voxels[1]).to(self.torch_device),
torch.arange(self.volume_dimensions_voxels[2]).to(self.torch_device),
indexing='ij')

target_vector_voxels = torch.subtract(torch.stack([x, y, z], axis=-1), start_voxels)
target_vector_voxels = torch.stack(torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0], dtype=torch.float, device=self.torch_device),
torch.arange(
self.volume_dimensions_voxels[1], dtype=torch.float, device=self.torch_device),
torch.arange(
self.volume_dimensions_voxels[2], dtype=torch.float, device=self.torch_device),
indexing='ij'), dim=-1)

target_vector_voxels -= start_voxels
target_vector_voxels = target_vector_voxels[:, :, :, 2]
if self.do_deformation:
# the deformation functional needs mm as inputs and returns the result in reverse indexing order...
deformation_values_mm = self.deformation_functional_mm(torch.arange(self.volume_dimensions_voxels[0]) *
deformation_values_mm = self.deformation_functional_mm(torch.arange(self.volume_dimensions_voxels[0], dtype=torch.float) *
self.voxel_spacing,
torch.arange(self.volume_dimensions_voxels[1]) *
torch.arange(self.volume_dimensions_voxels[1], dtype=torch.float) *
self.voxel_spacing).T
target_vector_voxels = (target_vector_voxels + torch.from_numpy(deformation_values_mm.reshape(
self.volume_dimensions_voxels[0],
self.volume_dimensions_voxels[1], 1)).to(self.torch_device) / self.voxel_spacing).float()

volume_fractions = torch.zeros(tuple(self.volume_dimensions_voxels), dtype=torch.float).to(self.torch_device)
volume_fractions = torch.zeros(tuple(self.volume_dimensions_voxels),
dtype=torch.float, device=self.torch_device)

if partial_volume:
bools_first_layer = ((target_vector_voxels >= -1) & (target_vector_voxels < 0))
Expand Down
29 changes: 16 additions & 13 deletions simpa/utils/libraries/structure_library/ParallelepipedStructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,35 +45,38 @@ def to_settings(self):

def get_enclosed_indices(self):
start_mm, x_edge_mm, y_edge_mm, z_edge_mm = self.params
start_mm = torch.tensor(start_mm, dtype=torch.float).to(self.torch_device)
x_edge_mm = torch.tensor(x_edge_mm, dtype=torch.float).to(self.torch_device)
y_edge_mm = torch.tensor(y_edge_mm, dtype=torch.float).to(self.torch_device)
z_edge_mm = torch.tensor(z_edge_mm, dtype=torch.float).to(self.torch_device)
start_mm = torch.tensor(start_mm, dtype=torch.float, device=self.torch_device)
x_edge_mm = torch.tensor(x_edge_mm, dtype=torch.float, device=self.torch_device)
y_edge_mm = torch.tensor(y_edge_mm, dtype=torch.float, device=self.torch_device)
z_edge_mm = torch.tensor(z_edge_mm, dtype=torch.float, device=self.torch_device)

start_voxels = start_mm / self.voxel_spacing
x_edge_voxels = x_edge_mm / self.voxel_spacing
y_edge_voxels = y_edge_mm / self.voxel_spacing
z_edge_voxels = z_edge_mm / self.voxel_spacing

x, y, z = torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0]).to(self.torch_device),
torch.arange(self.volume_dimensions_voxels[1]).to(self.torch_device),
torch.arange(self.volume_dimensions_voxels[2]).to(self.torch_device),
indexing='ij')

target_vector = torch.subtract(torch.stack([x, y, z], axis=-1), start_voxels)
target_vector = torch.stack(torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0], dtype=torch.float, device=self.torch_device),
torch.arange(
self.volume_dimensions_voxels[1], dtype=torch.float, device=self.torch_device),
torch.arange(
self.volume_dimensions_voxels[2], dtype=torch.float, device=self.torch_device),
indexing='ij'), dim=-1)
target_vector -= start_voxels

matrix = torch.stack((x_edge_voxels, y_edge_voxels, z_edge_voxels))

result = torch.linalg.solve(matrix.T.expand((target_vector.shape[:-1]+matrix.shape)), target_vector)
del target_vector

norm_vector = torch.tensor([1/torch.linalg.norm(x_edge_voxels),
1/torch.linalg.norm(y_edge_voxels),
1/torch.linalg.norm(z_edge_voxels)]).to(self.torch_device)
1/torch.linalg.norm(z_edge_voxels)], device=self.torch_device)

filled_mask_bool = (0 <= result) & (result <= 1 - norm_vector)

volume_fractions = torch.zeros(tuple(self.volume_dimensions_voxels), dtype=torch.float).to(self.torch_device)
filled_mask = torch.all(filled_mask_bool, axis=-1)
volume_fractions = torch.zeros(tuple(self.volume_dimensions_voxels),
dtype=torch.float, device=self.torch_device)
filled_mask = torch.all(filled_mask_bool, dim=-1)

volume_fractions[filled_mask] = 1

Expand Down
Loading

0 comments on commit ed02703

Please sign in to comment.