From 9343af468084aecc9a9dc98809a07a267ece897e Mon Sep 17 00:00:00 2001 From: Sunny Deng Date: Sun, 25 Jan 2026 23:17:45 +0800 Subject: [PATCH 1/3] Fix inference issue for mps --- moshi/moshi/models/lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moshi/moshi/models/lm.py b/moshi/moshi/models/lm.py index a64d736..31cbb24 100644 --- a/moshi/moshi/models/lm.py +++ b/moshi/moshi/models/lm.py @@ -976,7 +976,7 @@ def load_voice_prompt(self, voice_prompt: str): def load_voice_prompt_embeddings(self, path: str): self.voice_prompt = path - state = torch.load(path) + state = torch.load(path, map_location=torch.device(self.lm_model.device)) self.voice_prompt_audio = None self.voice_prompt_embeddings = state["embeddings"].to(self.lm_model.device) From da8a71a7cc7cafd03e6cf63c043429d791eaa0c9 Mon Sep 17 00:00:00 2001 From: Sunny Deng Date: Sun, 25 Jan 2026 23:32:56 +0800 Subject: [PATCH 2/3] Fix unsupported operation --- moshi/moshi/modules/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py index 553a2a8..168a662 100644 --- a/moshi/moshi/modules/transformer.py +++ b/moshi/moshi/modules/transformer.py @@ -265,8 +265,8 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult: B, H, T, D = k.shape indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset indexes = indexes % self.capacity - self.cache[0].index_copy_(2, indexes, k) - self.cache[1].index_copy_(2, indexes, v) + self.cache[0][:, :, indexes, :] = k + self.cache[1][:, :, indexes, :] = v self.end_offset.add_(T) keys = self.cache[0] From e6cd411ccbc3ee3d1e1e74318729d39a2bd88e6b Mon Sep 17 00:00:00 2001 From: Sunny Deng Date: Mon, 26 Jan 2026 09:38:35 +0800 Subject: [PATCH 3/3] only use fallback for mps --- moshi/moshi/modules/transformer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py index 168a662..e597790 100644 --- a/moshi/moshi/modules/transformer.py +++ b/moshi/moshi/modules/transformer.py @@ -265,8 +265,12 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult: B, H, T, D = k.shape indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset indexes = indexes % self.capacity - self.cache[0][:, :, indexes, :] = k - self.cache[1][:, :, indexes, :] = v + if k.device.type == "mps": + self.cache[0][:, :, indexes, :] = k + self.cache[1][:, :, indexes, :] = v + else: + self.cache[0].index_copy_(2, indexes, k) + self.cache[1].index_copy_(2, indexes, v) self.end_offset.add_(T) keys = self.cache[0]