diff --git a/moshi/moshi/models/lm.py b/moshi/moshi/models/lm.py index a64d736..ee8d867 100644 --- a/moshi/moshi/models/lm.py +++ b/moshi/moshi/models/lm.py @@ -976,7 +976,7 @@ def load_voice_prompt(self, voice_prompt: str): def load_voice_prompt_embeddings(self, path: str): self.voice_prompt = path - state = torch.load(path) + state = torch.load(path, map_location=self.lm_model.device) self.voice_prompt_audio = None self.voice_prompt_embeddings = state["embeddings"].to(self.lm_model.device)