diff --git a/moshi/moshi/models/lm.py b/moshi/moshi/models/lm.py
index a64d736..31cbb24 100644
--- a/moshi/moshi/models/lm.py
+++ b/moshi/moshi/models/lm.py
@@ -976,7 +976,7 @@ def load_voice_prompt(self, voice_prompt: str):
 
     def load_voice_prompt_embeddings(self, path: str):
         self.voice_prompt = path
-        state = torch.load(path)
+        state = torch.load(path, map_location=torch.device(self.lm_model.device))
         self.voice_prompt_audio = None
         self.voice_prompt_embeddings = state["embeddings"].to(self.lm_model.device)
 
diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py
index 553a2a8..e597790 100644
--- a/moshi/moshi/modules/transformer.py
+++ b/moshi/moshi/modules/transformer.py
@@ -265,8 +265,12 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
         B, H, T, D = k.shape
         indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
         indexes = indexes % self.capacity
-        self.cache[0].index_copy_(2, indexes, k)
-        self.cache[1].index_copy_(2, indexes, v)
+        if k.device.type == "mps":
+            self.cache[0][:, :, indexes, :] = k
+            self.cache[1][:, :, indexes, :] = v
+        else:
+            self.cache[0].index_copy_(2, indexes, k)
+            self.cache[1].index_copy_(2, indexes, v)
         self.end_offset.add_(T)
         keys = self.cache[0]
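
Note on the two hunks. The lm.py change is the usual map_location fix: without it, torch.load tries to restore tensors onto the device they were saved from (e.g. CUDA), which fails on an Apple-silicon machine. The transformer.py change swaps index_copy_ for an advanced-indexing assignment on MPS, presumably because index_copy_ misbehaves on that backend; both forms write the same ring-buffer slots. Below is a minimal, self-contained sketch of that write pattern. The names (cache, capacity, end_offset) follow the diff; the concrete sizes are illustrative only.

    import torch

    # Ring-buffer KV write, as in the complete() method patched above.
    # cache has shape (B, H, capacity, D); end_offset counts steps written.
    B, H, T, D, capacity = 1, 2, 3, 4, 8
    cache = torch.zeros(B, H, capacity, D)
    end_offset = torch.zeros(1, dtype=torch.long)
    k = torch.randn(B, H, T, D)

    # Wrap the next T write positions around the ring buffer.
    indexes = (torch.arange(T, dtype=end_offset.dtype) + end_offset) % capacity

    if k.device.type == "mps":
        # Advanced-indexing assignment: writes the same slots, but avoids
        # index_copy_, which the patch works around on the MPS backend.
        cache[:, :, indexes, :] = k
    else:
        cache.index_copy_(2, indexes, k)
    end_offset.add_(T)

The two branches are interchangeable in result; index_copy_ is kept on the non-MPS path since in-place index_copy_ is the established fast path on CUDA/CPU.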