Skip to content

Commit c957f71

Browse files
committed
added colorbar for spheroid
1 parent 5710e3b commit c957f71

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

saenopy/gui/common/ModuleColorBar.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from qtpy import QtCore, QtWidgets, QtGui
4+
from qimage2ndarray import array2qimage
5+
6+
7+
class ModuleColorBar(QtWidgets.QGroupBox):
8+
min_v = None
9+
max_v = None
10+
cmap = None
11+
tick_count = 3
12+
13+
def __init__(self, parent, view):
14+
QtWidgets.QWidget.__init__(self)
15+
self.parent = parent
16+
17+
self.font = QtGui.QFont()
18+
self.font.setPointSize(16)
19+
20+
self.tick = []
21+
for i in range(self.tick_count):
22+
tick_ticks = QtWidgets.QGraphicsTextItem("", view.hud_lowerLeft)
23+
tick_ticks.setFont(self.font)
24+
tick_ticks.setDefaultTextColor(QtGui.QColor("white"))
25+
26+
tick_line = QtWidgets.QGraphicsRectItem(0, 0, 1, 5, tick_ticks)
27+
tick_line.setBrush(QtGui.QBrush(QtGui.QColor("white")))
28+
tick_line.setPen(QtGui.QPen(QtGui.QColor("white")))
29+
self.tick.append([tick_ticks, tick_line])
30+
self.scalebar = QtWidgets.QGraphicsPixmapItem(view.hud_lowerLeft)
31+
self.scalebar.setPos(20, -20)
32+
33+
self.updateStatus()
34+
35+
def updateStatus(self):
36+
self.updateBar()
37+
38+
def setScale(self, min_v, max_v, cmap):
39+
self.min_v = min_v
40+
self.max_v = max_v
41+
self.cmap = cmap
42+
self.updateBar()
43+
44+
def updateBar(self):
45+
if self.min_v is None or self.max_v is None or self.cmap is None:
46+
return
47+
bar_width = 200
48+
ofset_x = 20
49+
ofset_y = 25
50+
colors = np.zeros((10, bar_width, 3), dtype=np.uint8)
51+
for i in range(bar_width):
52+
c = plt.get_cmap(self.cmap)(int(i/bar_width*255))
53+
colors[:, i, :] = [c[0]*255, c[1]*255, c[2]*255]
54+
self.scalebar.setPixmap(QtGui.QPixmap(array2qimage(colors)))
55+
self.scalebar.setPos(ofset_x, -ofset_y)
56+
57+
import matplotlib.ticker as ticker
58+
59+
locator = ticker.MaxNLocator(nbins=self.tick_count-1)
60+
tick_positions = locator.tick_values(self.min_v, self.max_v)
61+
tick_positions = np.linspace(self.min_v, self.max_v, self.tick_count)
62+
for i, pos in enumerate(tick_positions):
63+
self.tick[i][0].setPos(ofset_x - 50 + (bar_width-1)/(self.tick_count-1)*i, -ofset_y - 33)
64+
self.tick[i][1].setPos(+ 50, 33 - 5)
65+
self.tick[i][0].setTextWidth(100)
66+
self.tick[i][0].setHtml(f"<center>{int(pos):d}</center>")

saenopy/gui/solver/modules/exporter/ExporterRender2D.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def render_2d_image(params, result, exporter):
6161
return pil_image, display_image, im_scale, aa_scale
6262

6363

64-
def render_2d_arrows(params, result, pil_image, im_scale, aa_scale, display_image):
64+
def render_2d_arrows(params, result, pil_image, im_scale, aa_scale, display_image, return_scale=False):
6565
def project_data(R, field, skip=1):
6666
length = np.linalg.norm(field, axis=1)
6767
angle = np.arctan2(field[:, 1], field[:, 0])
@@ -77,6 +77,8 @@ def project_data(R, field, skip=1):
7777
mesh, field, params_arrows, name = get_mesh_arrows(params, result)
7878

7979
if mesh is None:
80+
if return_scale:
81+
return pil_image, None
8082
return pil_image
8183

8284
scale_max = params_arrows["scale_max"] if not params_arrows["autoscale"] else None
@@ -106,6 +108,7 @@ def project_data(R, field, skip=1):
106108

107109
if scale_max is None:
108110
max_length = np.nanmax(np.linalg.norm(field, axis=1))# * params_arrows["arrow_scale"]
111+
scale_max = max_length / params_arrows["arrow_scale"]
109112
else:
110113
max_length = scale_max * params_arrows["arrow_scale"]
111114

@@ -129,6 +132,8 @@ def project_data(R, field, skip=1):
129132
width=params["2D_arrows"]["width"],
130133
headlength=params["2D_arrows"]["headlength"],
131134
headheight=params["2D_arrows"]["headheight"])
135+
if return_scale:
136+
return pil_image, scale_max
132137
return pil_image
133138

134139

saenopy/gui/spheroid/modules/DeformationDetector.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from saenopy.gui.common.resources import resource_icon
1616
from saenopy.gui.common.code_export import get_code
1717
from saenopy.gui.common.ModuleScaleBar import ModuleScaleBar
18+
from saenopy.gui.common.ModuleColorBar import ModuleColorBar
1819

1920

2021
class DeformationDetector(PipelineModule):
@@ -76,6 +77,7 @@ def __init__(self, parent: "BatchEvaluate", layout):
7677
#layout.addWidget(self.plotter.interactor)
7778
self.label = QExtendedGraphicsView.QExtendedGraphicsView().addToLayout()
7879
self.scale1 = ModuleScaleBar(self, self.label)
80+
self.color1 = ModuleColorBar(self, self.label)
7981
#self.label.setMinimumWidth(300)
8082
self.pixmap = QtWidgets.QGraphicsPixmapItem(self.label.origin)
8183
self.contour = QtWidgets.QGraphicsPathItem(self.label.origin)
@@ -207,10 +209,10 @@ def update_display(self, *, plotter=None):
207209
pil_image = pil_image.resize(
208210
[int(pil_image.width * im_scale * aa_scale), int(pil_image.height * im_scale * aa_scale)])
209211
#print(self.auto_scale.value(), self.getScaleMax())
210-
pil_image = render_2d_arrows({
212+
pil_image, scale_max = render_2d_arrows({
211213
'arrows': 'deformation',
212214
'deformation_arrows': {
213-
"autoscale": not self.auto_scale.value(),
215+
"autoscale": self.auto_scale.value(),
214216
"scale_max": self.getScaleMax(),
215217
"colormap": self.colormap_chooser.value(),
216218
"skip": 1,
@@ -219,12 +221,13 @@ def update_display(self, *, plotter=None):
219221
},
220222
"time": {"t": t},
221223
'2D_arrows': {'width': 2.0, 'headlength': 5.0, 'headheight': 5.0},
222-
}, self.result, pil_image, im_scale, aa_scale, display_image)
224+
}, self.result, pil_image, im_scale, aa_scale, display_image, return_scale=True)
223225

224226
im = np.asarray(pil_image)
225227
self.pixmap.setPixmap(QtGui.QPixmap(array2qimage(im)))
226228
self.label.setExtend(im.shape[1], im.shape[0])
227229
self.scale1.setScale([self.result.pixel_size])
230+
self.color1.setScale(0, scale_max, self.colormap_chooser.value())
228231

229232
if self.show_seg.value():
230233
thresh_segmentation = self.thresh_segmentation.value()

0 commit comments

Comments
 (0)