diff --git a/config/misc/default.yaml b/config/misc/default.yaml index b41c965..db0a12b 100644 --- a/config/misc/default.yaml +++ b/config/misc/default.yaml @@ -1,3 +1,3 @@ seed: 42 debug: False -modes: [ train, valid ] +modes: [ train , valid , test ] diff --git a/sage/data/dataloader.py b/sage/data/dataloader.py index e9e355a..60c8a8f 100755 --- a/sage/data/dataloader.py +++ b/sage/data/dataloader.py @@ -283,19 +283,6 @@ def remove_duplicates(self, labels: pd.DataFrame) -> pd.DataFrame: labels = labels[~_dups_bool] return labels - def _exclude_data(self, - lst: pd.DataFrame, - root: Path, - exclusion_fname: str = "exclusion.csv") -> List[Path]: - try: - exc = pd.read_csv(root / exclusion_fname, header=None) - exclusion = set(exc.values.flatten().tolist()) - lst = [f for f in lst if f not in exclusion] - except: - logger.info("No exclusion file found. %s", root / exclusion_fname) - pass - return lst - class UKBClassification(UKBDataset): def __init__(self, diff --git a/sage/xai/trainer.py b/sage/xai/trainer.py index 0a286a1..2eadd10 100644 --- a/sage/xai/trainer.py +++ b/sage/xai/trainer.py @@ -124,7 +124,7 @@ def _configure_xai(self, attr_mtd = ca.InputXGradient(forward_func=model.backbone) xai = ca.NoiseTunnel(attribution_method=attr_mtd) if xai_call_kwarg is None: - xai_call_kwarg = dict(nt_type="smoothgrad", nt_samples=15) + xai_call_kwarg = dict(nt_type="smoothgrad", nt_samples=10) else: breakpoint() self.xai_call_kwarg = dict() if xai_call_kwarg is None else xai_call_kwarg