Skip to content

Commit

Permalink
python: fixed bug where update_charts() didn't work on images with di…
Browse files Browse the repository at this point in the history
…fferent sizes

sewar_rmse() should return 1 metric if the image sizes are different
  • Loading branch information
adamgeorge309 committed Nov 15, 2024
1 parent 2d5193f commit 0ba3d0f
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions python/inet/test/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,15 @@


def sewar_rmse(a, b):
a,b = sewar_initial_check(a,b)
assert a.shape == b.shape, "Supplied images have different sizes " + \
str(a.shape) + " and " + str(b.shape)
if len(a.shape) == 2:
a = a[:,:,numpy.newaxis]
b = b[:,:,numpy.newaxis]
a = a.astype(numpy.float64)
b = b.astype(numpy.float64)
return numpy.sqrt(numpy.mean((a.astype(numpy.float64)-b.astype(numpy.float64))**2))

def sewar_initial_check(GT,P):
assert GT.shape == P.shape, "Supplied images have different sizes " + \
str(GT.shape) + " and " + str(P.shape)
if GT.dtype != P.dtype:
msg = "Supplied images have different dtypes " + \
str(GT.dtype) + " and " + str(P.dtype)
_logger.warn(msg)

if len(GT.shape) == 2:
GT = GT[:,:,numpy.newaxis]
P = P[:,:,numpy.newaxis]

return GT.astype(numpy.float64),P.astype(numpy.float64)

class ChartTestTask(TestTask):
def __init__(self, analysis_file_name, id, chart_name, simulation_project=None, name="chart test", **kwargs):
super().__init__(name=name, **kwargs)
Expand Down Expand Up @@ -77,8 +69,12 @@ def run_protected(self, keep_charts=True, output_stream=sys.stdout, **kwargs):
if os.path.exists(old_file_name):
new_image = matplotlib.image.imread(new_file_name)
old_image = matplotlib.image.imread(old_file_name)
if old_image.shape != new_image.shape:
return self.task_result_class(self, result="FAIL", reason="Supplied images have different sizes" + str(old_image.shape) + " and " + str(new_image.shape))
metric = sewar_rmse(old_image, new_image)
if metric == 0 or not keep_charts:
if type(metric) == str:
_logger.info(metric)
elif metric == 0 or not keep_charts:
os.remove(new_file_name)
else:
image_diff = numpy.abs(new_image - old_image)
Expand Down Expand Up @@ -183,7 +179,10 @@ def run_protected(self, keep_charts=True, **kwargs):
if os.path.exists(old_file_name):
new_image = matplotlib.image.imread(new_file_name)
old_image = matplotlib.image.imread(old_file_name)
metric = sewar_rmse(old_image, new_image)
if old_image.shape != new_image.shape:
metric = 1
else:
metric = sewar_rmse(old_image, new_image)
if metric == 0:
os.remove(new_file_name)
else:
Expand Down

0 comments on commit 0ba3d0f

Please sign in to comment.