From 5842affb9f2d71f1a5baacfca3b0d90ea8b6087a Mon Sep 17 00:00:00 2001 From: Leo Wang Date: Fri, 14 Mar 2025 02:30:23 -0700 Subject: [PATCH] [server] directly initialize msg_enc to torch.float32 to support mps --- src/silentcipher/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/silentcipher/server.py b/src/silentcipher/server.py index 708ad75..7964133 100644 --- a/src/silentcipher/server.py +++ b/src/silentcipher/server.py @@ -314,7 +314,7 @@ def binary_encode(mes): binary_encoded_message = binary_encode(message) msgs, msgs_compact = self.letters_encoding(carrier.shape[3], [binary_encoded_message]) - msg_enc = torch.tensor(msgs, device=self.device).unsqueeze(0).float() + msg_enc = torch.tensor(msgs, device=self.device, dtype=torch.float32).unsqueeze(0) carrier_enc = self.enc_c(carrier) # encode the carrier msg_enc = self.enc_c.transform_message(msg_enc)