diff --git a/pe/callback/common/compute_fid.py b/pe/callback/common/compute_fid.py index 8d3dd6d..a584c9a 100644 --- a/pe/callback/common/compute_fid.py +++ b/pe/callback/common/compute_fid.py @@ -47,6 +47,6 @@ def __call__(self, syn_data): mu2=syn_mu, sigma2=syn_sigma, ) - metric_item = FloatMetricItem(name=f"fid_{type(self._embedding).__name__}", value=fid) + metric_item = FloatMetricItem(name=f"fid_{self._embedding.column_name}", value=fid) execution_logger.info(f"Finished computing FID ({type(self._embedding).__name__})") return [metric_item]