diff --git a/requirements-dev.txt b/requirements-dev.txt
index 84b5700d3..f02124df1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,4 +1,4 @@
-accelerate>=0.11.0
+accelerate>=0.22.0
 fire
 flaky
 future>=0.17.1
diff --git a/skorch/tests/test_hf.py b/skorch/tests/test_hf.py
index e41c30496..63f0c2476 100644
--- a/skorch/tests/test_hf.py
+++ b/skorch/tests/test_hf.py
@@ -660,10 +660,32 @@ def test_mixed_precision_save_load_params(
         accelerator = accelerator_cls(mixed_precision=mixed_precision)
         net = net_cls(accelerator=accelerator)
         net.initialize()
+
         filename = tmp_path / 'accel-net-params.pth'
         net.save_params(f_params=filename)
         net.load_params(f_params=filename)
 
+    @pytest.mark.parametrize('mixed_precision', ['fp16', 'bf16', 'no'])
+    def test_mixed_precision_inference(
+            self, net_cls, accelerator_cls, data, mixed_precision, tmp_path
+    ):
+        from accelerate.utils import is_bf16_available
+
+        if (mixed_precision != 'no') and not torch.cuda.is_available():
+            pytest.skip('skipping AMP test because device does not support it')
+        if (mixed_precision == 'bf16') and not is_bf16_available():
+            pytest.skip('skipping bf16 test because device does not support it')
+
+        X, y = data
+        accelerator = accelerator_cls(mixed_precision=mixed_precision)
+        net = net_cls(accelerator=accelerator)
+        net.fit(X, y)
+        net.predict(X)
+        net.predict_proba(X)
+
+        Xt = torch.from_numpy(X).to(net.device)
+        net.forward(Xt)
+
     def test_force_cpu(self, net_cls, accelerator_cls, data):
         accelerator = accelerator_cls(device_placement=False, cpu=True)
         net = net_cls(accelerator=accelerator)