# simulation_voltage.py
import copy
import numpy as np
import vtk
from vtk.util import numpy_support
from cell import Cell
def calculate_neighborhood_activation_score(column, cell_info, drid_idx):
    """Return the scaled mean activation of a cell's input neighbors.

    Adds up ``activation_state`` over the cell's local neighbors (looked up
    in the drid at ``drid_idx``) and over its non-local input neighbors
    (each looked up in the drid it names), then returns
    ``total * 1.5 / neighbor_count``.

    Args:
        column: Object exposing ``drids``, an indexable collection. Each drid
            provides ``grid.array.shape`` and ``cell_dict``, a mapping from
            cell key to a sequence whose first item is the cell object.
        cell_info: 7-item sequence ``(cell, cell_index, local_neighbors,
            role, code, input_non_local, output_non_local)``. Only
            ``local_neighbors`` and ``input_non_local`` are read here.
        drid_idx: Index into ``column.drids`` of the drid owning the cell.

    Returns:
        The activation score, or the integer ``0`` when the cell has no
        local and no non-local input neighbors.
    """
    _, _, local_neighbors, _, _, input_non_local, _ = cell_info
    home_drid = column.drids[drid_idx]
    grid_shape = home_drid.grid.array.shape

    # Local neighbors are addressed by their flat C-order index; mode='wrap'
    # folds out-of-range coordinates back onto the grid.
    # NOTE(review): count_active_inactive_neighbors builds its keys with
    # "+ 1" and without wrapping -- confirm which key scheme cell_dict uses.
    local_keys = (
        f"cell{np.ravel_multi_index(idx, grid_shape, mode='wrap', order='C')}"
        for idx in local_neighbors
    )
    total = sum(home_drid.cell_dict[key][0].activation_state for key in local_keys)

    # Non-local inputs carry their own (drid index, cell key, _) address.
    total = sum(
        (
            column.drids[src_drid_idx].cell_dict[src_key][0].activation_state
            for src_drid_idx, src_key, _ in input_non_local
        ),
        total,
    )

    neighbor_count = len(local_neighbors) + len(input_non_local)
    if not neighbor_count:
        return 0
    return total * 1.5 / neighbor_count
def count_active_inactive_neighbors(column, cell_info, drid_idx):
    """Count how many of a cell's input neighbors are active and inactive.

    A neighbor counts as active when its ``activation_state`` is truthy.
    Local neighbors are looked up in the drid at ``drid_idx``; non-local
    input neighbors are looked up in the drid each entry names.

    Args:
        column: Object exposing ``drids``, an indexable collection. Each drid
            provides ``grid.array.shape`` and ``cell_dict``, a mapping from
            cell key to a sequence whose first item is the cell object.
        cell_info: 7-item sequence ``(cell, cell_index, local_neighbors,
            role, code, input_non_local, output_non_local)``. Only
            ``local_neighbors`` and ``input_non_local`` are read here.
        drid_idx: Index into ``column.drids`` of the drid owning the cell.

    Returns:
        Tuple ``(active_neighbors, inactive_neighbors)``.
    """
    _, _, local_neighbors, _, _, input_non_local, _ = cell_info
    home_drid = column.drids[drid_idx]
    grid_shape = home_drid.grid.array.shape

    # Local keys here are the flat index plus one, with no wrapping.
    # NOTE(review): calculate_neighborhood_activation_score uses the flat
    # index without "+ 1" and with mode='wrap' -- confirm which is intended.
    states = [
        home_drid.cell_dict[
            f"cell{np.ravel_multi_index(idx, grid_shape) + 1}"
        ][0].activation_state
        for idx in local_neighbors
    ]
    states.extend(
        column.drids[src_drid_idx].cell_dict[src_key][0].activation_state
        for src_drid_idx, src_key, _ in input_non_local
    )

    active = sum(1 for state in states if state)
    return active, len(states) - active
class Worker:
    """Advance every cell of a column by one simulation step."""

    def __init__(self, column):
        # Column whose ``drids`` entries are replaced by process_column().
        self.column = column

    def process_column(self):
        """Compute the next state of every cell and store it on the column.

        Each drid is deep-copied, every cell in the copy is replaced by a
        freshly built ``Cell`` holding the next state, and the copy then
        replaces the drid in ``self.column.drids``. Because a drid is swapped
        in before the next one is processed, non-local lookups into
        lower-index drids see already-updated cells.

        Transition rules (all comparisons use the cell's pre-step voltage):

        * excitability 1, activation 0: voltage gains the neighborhood
          score; additionally, above ``upper_tresh`` the cell fires
          (activation 1, voltage reset to 1, excitability 0), and below
          ``upper_tresh`` the voltage loses ``decay_value``.
        * excitability 0, activation 0, voltage <= ``lower_tresh``:
          excitability returns to 1 and voltage gains the score.
        * excitability 0, activation 1, voltage >= ``upper_tresh``:
          activation drops to 0 and voltage loses ``decay_value``.
        * any other combination is carried over unchanged.
        """
        for drid_idx, drid in enumerate(self.column.drids):
            next_drid = copy.deepcopy(drid)

            for cell_key, cell_info in drid.cell_dict.items():
                (cell, cell_index, local_neighbors, role, code,
                 input_non_local, output_non_local) = cell_info

                # Start from a copy of the current state; rules below edit it.
                successor = Cell(cell.voltage, cell.activation_state,
                                 cell.excitability)
                score = calculate_neighborhood_activation_score(
                    self.column, cell_info, drid_idx)

                excitable = cell.excitability == 1
                refractory = cell.excitability == 0
                idle = cell.activation_state == 0
                firing = cell.activation_state == 1

                if excitable and idle:
                    successor.voltage += score
                    if cell.voltage > drid.grid.upper_tresh:
                        # Threshold crossed: fire and become non-excitable.
                        successor.activation_state = 1
                        successor.voltage = 1
                        successor.excitability = 0
                    elif cell.voltage < drid.grid.upper_tresh:
                        # Sub-threshold: stay idle and leak some voltage.
                        successor.activation_state = 0
                        successor.voltage -= drid.grid.decay_value
                        successor.excitability = 1
                elif refractory and idle:
                    if cell.voltage <= drid.grid.lower_tresh:
                        successor.excitability = 1
                        successor.voltage += score
                elif refractory and firing:
                    if cell.voltage >= drid.grid.upper_tresh:
                        successor.activation_state = 0
                        successor.voltage -= drid.grid.decay_value

                # Swap the updated cell into the copied drid's entry.
                next_drid.cell_dict[cell_key] = (
                    successor, cell_index, local_neighbors, role, code,
                    input_non_local, output_non_local)

            self.column.drids[drid_idx] = next_drid
class ResultColumn:
    """Per-drid history of voltage and activation grids, one entry per frame."""

    def __init__(self):
        # drid index -> list of voltage grids in the order they were added
        self.voltage_grids = {}
        # drid index -> list of activation grids in the order they were added
        self.activation_grids = {}

    def add_voltage_grid(self, drid_idx, grid):
        """Append ``grid`` to the voltage history of drid ``drid_idx``."""
        self.voltage_grids.setdefault(drid_idx, []).append(grid)

    def add_activation_grid(self, drid_idx, grid):
        """Append ``grid`` to the activation history of drid ``drid_idx``."""
        self.activation_grids.setdefault(drid_idx, []).append(grid)
def run_simulation(column, num_frames):
    """Step ``column`` forward ``num_frames`` times and record each frame.

    After every step, one voltage grid and one activation grid are built per
    drid (shaped like ``drid.grid.array``) and filled from the cells'
    ``voltage`` and ``activation_state`` at their ``cell_index``. The grids
    are freshly allocated each frame, so recorded frames do not share
    storage.

    Args:
        column: Object exposing ``drids``; it is mutated by the simulation.
        num_frames: Number of simulation steps to run.

    Returns:
        A ``ResultColumn`` holding, per drid index, one voltage grid and one
        activation grid for every frame.
    """
    stepper = Worker(column)
    history = ResultColumn()

    for _ in range(num_frames):
        # Advance all cells by one step before sampling their state.
        stepper.process_column()

        for drid_idx, drid in enumerate(column.drids):
            shape = drid.grid.array.shape
            voltages = np.zeros(shape)
            activations = np.zeros(shape)

            # Each cell_dict entry starts with (cell, cell_index, ...).
            for cell, cell_index, *_unused in drid.cell_dict.values():
                voltages[cell_index] = cell.voltage
                activations[cell_index] = cell.activation_state

            history.add_voltage_grid(drid_idx, voltages)
            history.add_activation_grid(drid_idx, activations)

    return history
def are_grids_different(grid1, grid2, threshold=0.001):
    """Return whether two grids differ by more than ``threshold`` anywhere.

    Args:
        grid1: First grid; a NumPy array or any array-like (e.g. nested
            lists) that broadcasts against ``grid2``.
        grid2: Second grid, same requirements as ``grid1``.
        threshold: Largest absolute element-wise difference that still
            counts as "the same".

    Returns:
        ``True`` if the largest absolute element-wise difference exceeds
        ``threshold``, otherwise ``False``. Two empty grids are never
        different.
    """
    diff = np.abs(np.asarray(grid1) - np.asarray(grid2))
    # np.max raises ValueError on a zero-size array; nothing to compare
    # means nothing differs.
    if diff.size == 0:
        return False
    # Convert from np.bool_ so callers get a plain Python bool.
    return bool(np.max(diff) > threshold)
def numpy_to_vtk_image_data(numpy_array):
    """Wrap a NumPy array in a ``vtkImageData`` with unit spacing at the origin.

    The array's shape is passed as the image dimensions and its values are
    flattened in Fortran order (first axis varying fastest) into a
    float-typed VTK array that becomes the image's point scalars.

    Args:
        numpy_array: Array to convert. Its ``shape`` is handed directly to
            ``SetDimensions`` -- presumably a 3-D array; confirm with callers.

    Returns:
        The populated ``vtkImageData`` object.
    """
    # Fortran-order flattening so the first array axis varies fastest,
    # matching the dimensions given to SetDimensions below.
    flat_values = numpy_array.ravel(order='F')
    scalars = numpy_support.numpy_to_vtk(flat_values,
                                         deep=True,
                                         array_type=vtk.VTK_FLOAT)

    image = vtk.vtkImageData()
    image.SetDimensions(numpy_array.shape)
    image.SetSpacing(1, 1, 1)
    image.SetOrigin(0, 0, 0)
    image.GetPointData().SetScalars(scalars)
    return image
def save_combined_vti_files(result_column,
                            num_frames,
                            output_folder,
                            offset):
    """Write one VTI file per frame containing every drid's voltage grid.

    For each frame, the voltage grids of all drids are pasted into a single
    3-D array, drid ``i`` starting at index ``i * offset`` along the first
    axis, and the result is saved as
    ``{output_folder}/combined_drids_frame_{frame:04d}.vti`` with the scalar
    array named ``"Voltage"``. If ``offset`` is smaller than a grid's first
    dimension, later drids overwrite the overlapping part of earlier ones.

    Args:
        result_column: Object whose ``voltage_grids`` maps drid index to a
            list of 3-D voltage grids, one per frame. All grids are taken to
            have the shape of drid 0's first frame.
        num_frames: Number of frames to write, starting at frame 0.
        output_folder: Directory the files are written into.
        offset: Spacing along the first axis between consecutive drids.

    Returns:
        None. Does nothing when there are no voltage grids or no frames.
    """
    voltage_grids = result_column.voltage_grids
    # Nothing recorded (or nothing requested): avoid the KeyError that
    # indexing drid 0 would otherwise raise.
    if not voltage_grids or num_frames <= 0:
        return

    num_drids = len(voltage_grids)
    grid_shape = voltage_grids[0][0].shape
    grid_len_x = grid_shape[0]
    # The combined shape is the same for every frame, so compute it once.
    combined_shape = (grid_len_x + (num_drids - 1) * offset,
                      grid_shape[1], grid_shape[2])

    for frame_idx in range(num_frames):
        combined_array = np.zeros(combined_shape)
        for drid_idx, drid_frames in voltage_grids.items():
            x_start = drid_idx * offset
            combined_array[x_start:x_start + grid_len_x, :, :] = \
                drid_frames[frame_idx]

        vtk_image_data = vtk.vtkImageData()
        vtk_image_data.SetDimensions(combined_shape[0], combined_shape[1],
                                     combined_shape[2])
        vtk_image_data.SetSpacing(1, 1, 1)
        # vtkImageData stores points with the first (x) index varying
        # fastest, so flatten in Fortran order to keep array axis 0 on x.
        # A C-order flatten scrambles the volume whenever the axes differ
        # in length.
        vtk_array = numpy_support.numpy_to_vtk(
            combined_array.ravel(order='F'),
            deep=True,
            array_type=vtk.VTK_FLOAT)
        vtk_array.SetName("Voltage")
        vtk_image_data.GetPointData().SetScalars(vtk_array)

        output_file_name = f"{output_folder}/combined_drids_frame_{frame_idx:04d}.vti"
        writer = vtk.vtkXMLImageDataWriter()
        writer.SetFileName(output_file_name)
        writer.SetInputData(vtk_image_data)
        writer.Write()