diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1adf75e79..835e002ed 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -690,8 +690,8 @@ def to(self, *args, **kwargs): requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) - new_param.CB = self.CB - new_param.SCB = self.SCB + new_param.CB = self.CB.to(device=device) if self.CB is not None else self.CB + new_param.SCB = self.SCB.to(device=device) if self.SCB is not None else self.SCB return new_param