Skip to content

Commit 6377d91

Browse files
committed
added colorbar for 2d image export
1 parent f3d4d8d commit 6377d91

File tree

3 files changed

+95
-9
lines changed

3 files changed

+95
-9
lines changed

saenopy/gui/solver/modules/exporter/ExportRenderCommon.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@ def get_mesh_arrows(params, result):
1313
if data["fields"][params["arrows"]]["measure"] == "deformation":
1414
if mesh is not None and field is not None:
1515
return mesh, field, params["deformation_arrows"], data["fields"][params["arrows"]]["name"]
16+
else:
17+
return None, None, params["deformation_arrows"], data["fields"][params["arrows"]]["name"]
1618
if data["fields"][params["arrows"]]["measure"] == "force":
1719
if mesh is not None and field is not None:
1820
return mesh, field, params["force_arrows"], data["fields"][params["arrows"]]["name"]
21+
else:
22+
return None, None, params["force_arrows"], data["fields"][params["arrows"]]["name"]
1923
return None, None, {}, ""
2024

2125

saenopy/gui/solver/modules/exporter/ExporterRender2D.py

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ def render_2d(params, result, exporter=None):
1414
if pil_image is None:
1515
return np.zeros((10, 10))
1616

17-
pil_image = render_2d_arrows(params, result, pil_image, im_scale, aa_scale, display_image)
17+
pil_image, disp_params = render_2d_arrows(params, result, pil_image, im_scale, aa_scale, display_image, return_scale=True)
1818

1919
if aa_scale == 2:
2020
pil_image = pil_image.resize([pil_image.width // 2, pil_image.height // 2])
2121
aa_scale = 1
2222

2323
pil_image = render_2d_scalebar(params, result, pil_image, im_scale, aa_scale)
24+
if disp_params != None:
25+
pil_image = render_2d_colorbar(params, result, pil_image, im_scale, aa_scale, scale_max=disp_params["scale_max"], colormap=disp_params["colormap"])
2426

2527
pil_image = render_2d_time(params, result, pil_image)
2628

@@ -76,16 +78,23 @@ def project_data(R, field, skip=1):
7678

7779
mesh, field, params_arrows, name = get_mesh_arrows(params, result)
7880

81+
if params_arrows is None:
82+
scale_max = None
83+
else:
84+
scale_max = params_arrows["scale_max"] if not params_arrows["autoscale"] else None
85+
colormap = params_arrows["colormap"]
86+
skip = params_arrows["skip"]
87+
alpha = params_arrows["arrow_opacity"]
88+
7989
if mesh is None:
8090
if return_scale:
91+
if scale_max is None:
92+
return pil_image, None
93+
else:
94+
return pil_image, {"scale_max": scale_max, "colormap": colormap}
8195
return pil_image, None
8296
return pil_image
8397

84-
scale_max = params_arrows["scale_max"] if not params_arrows["autoscale"] else None
85-
colormap = params_arrows["colormap"]
86-
skip = params_arrows["skip"]
87-
alpha = params_arrows["arrow_opacity"]
88-
8998
if field is not None:
9099
# rescale and offset
91100
scale = 1e6 / display_image[1][0]
@@ -133,7 +142,7 @@ def project_data(R, field, skip=1):
133142
headlength=params["2D_arrows"]["headlength"],
134143
headheight=params["2D_arrows"]["headheight"])
135144
if return_scale:
136-
return pil_image, scale_max
145+
return pil_image, {"scale_max": scale_max, "colormap": colormap}
137146
return pil_image
138147

139148

@@ -166,6 +175,22 @@ def getBarParameters(pixtomu, scale=1):
166175
size_in_um=mu, color="w", unit="µm")
167176
return pil_image
168177

178+
def render_2d_colorbar(params, result, pil_image, im_scale, aa_scale, colormap="viridis", scale_max=1):
179+
pil_image = add_colorbar(pil_image, scale=1,
180+
colormap=colormap,#params["colorbar"]["colorbar"],
181+
#bar_width=params["colorbar"]["bar_width"] * aa_scale,
182+
#bar_height=params["colorbar"]["bar_height"] * aa_scale,
183+
#tick_height=params["colorbar"]["tick_height"] * aa_scale,
184+
#tick_count=params["colorbar"]["tick_count"],
185+
#min_v=params["scalebar"]["min_v"],
186+
max_v=scale_max,#params["colorbar"]["max_v"],
187+
#offset_x=params["colorbar"]["offset_x"] * aa_scale,
188+
#offset_y=params["colorbar"]["offset_y"] * aa_scale,
189+
#fontsize=params["colorbar"]["fontsize"] * aa_scale,
190+
)
191+
192+
return pil_image
193+
169194

170195
def render_2d_time(params, result, pil_image):
171196
data = result.get_data_structure()
@@ -245,6 +270,63 @@ def add_text(pil_image, text, position, fontsize=18):
245270
image.text((x, y), text, color, font=font)
246271
return pil_image
247272

273+
def add_colorbar(pil_image,
274+
colormap="viridis",
275+
bar_width=150,
276+
bar_height=10,
277+
tick_height=5,
278+
tick_count=3,
279+
min_v=0,
280+
max_v=10,
281+
offset_x=15,
282+
offset_y=-10,
283+
scale=1, fontsize=16, color="w"):
284+
cmap = plt.get_cmap(colormap)
285+
if offset_x < 0:
286+
offset_x = pil_image.size[0] + offset_x
287+
if offset_y < 0:
288+
offset_y = pil_image.size[1] + offset_y
289+
290+
color = tuple((matplotlib.colors.to_rgba_array(color)[0, :3] * 255).astype("uint8"))
291+
if pil_image.mode != "RGB":
292+
color = int(np.mean(color))
293+
294+
colors = np.zeros((bar_height, bar_width, 3), dtype=np.uint8)
295+
for i in range(bar_width):
296+
c = plt.get_cmap(cmap)(int(i / bar_width * 255))
297+
colors[:, i, :] = [c[0] * 255, c[1] * 255, c[2] * 255]
298+
pil_image.paste(Image.fromarray(colors), (offset_x, offset_y - bar_height))
299+
300+
image = ImageDraw.ImageDraw(pil_image)
301+
import matplotlib.ticker as ticker
302+
303+
font_size = int(
304+
round(fontsize * scale * 4 / 3)) # the 4/3 appears to be a factor of "converting" screel dpi to image dpi
305+
try:
306+
font = ImageFont.truetype("arial", font_size) # ImageFont.truetype("tahoma.ttf", font_size)
307+
except IOError:
308+
font = ImageFont.truetype("times", font_size)
309+
310+
locator = ticker.MaxNLocator(nbins=tick_count - 1)
311+
#tick_positions = locator.tick_values(min_v, max_v)
312+
tick_positions = np.linspace(min_v, max_v, tick_count)
313+
for i, pos in enumerate(tick_positions):
314+
x0 = offset_x + (bar_width - 2) / (tick_count - 1) * i
315+
y0 = offset_y - bar_height - 1
316+
317+
image.rectangle([x0, y0-5, x0+1, y0])
318+
319+
text = "%d" % pos
320+
length_number = image.textlength(text, font=font)
321+
height_number = image.textbbox((0, 0), text, font=font)[3]
322+
323+
x = x0 - length_number * 0.5 + 1
324+
y = y0 - height_number - tick_height - 3
325+
# draw the text for the number and the unit
326+
image.text((x, y), text, color, font=font)
327+
#image.rectangle([pil_image.size[0]-10, 0, pil_image.size[0], 10], fill="w")
328+
return pil_image
329+
248330
def add_scalebar(pil_image, scale, image_scale, width, xpos, ypos, fontsize, pixel_width, size_in_um, color="w", unit="µm"):
249331
image = ImageDraw.ImageDraw(pil_image)
250332
pixel_height = width

saenopy/gui/spheroid/modules/DeformationDetector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def update_display(self, *, plotter=None):
210210
pil_image = pil_image.resize(
211211
[int(pil_image.width * im_scale * aa_scale), int(pil_image.height * im_scale * aa_scale)])
212212
#print(self.auto_scale.value(), self.getScaleMax())
213-
pil_image, scale_max = render_2d_arrows({
213+
pil_image, disp_params = render_2d_arrows({
214214
'arrows': 'deformation',
215215
'deformation_arrows': {
216216
"autoscale": self.auto_scale.value(),
@@ -228,7 +228,7 @@ def update_display(self, *, plotter=None):
228228
self.pixmap.setPixmap(QtGui.QPixmap(array2qimage(im)))
229229
self.label.setExtend(im.shape[1], im.shape[0])
230230
self.scale1.setScale([self.result.pixel_size])
231-
self.color1.setScale(0, scale_max, self.colormap_chooser.value())
231+
self.color1.setScale(0, disp_params["scale_max"] if disp_params else None, self.colormap_chooser.value())
232232

233233
if self.show_seg.value():
234234
thresh_segmentation = self.thresh_segmentation.value()

0 commit comments

Comments
 (0)