diff --git a/easyeditor/models/pmet/pmet_main.py b/easyeditor/models/pmet/pmet_main.py index 5112e0d7..23b0bc65 100644 --- a/easyeditor/models/pmet/pmet_main.py +++ b/easyeditor/models/pmet/pmet_main.py @@ -336,6 +336,7 @@ def get_cov( sample_size=mom2_n_samples, precision=mom2_dtype, force_recompute=force_recompute, + hparams=hparams ) COV_CACHE[key] = stat.mom2.moment().float().to("cpu")