diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py
index 771f491..dc35310 100644
--- a/moshi/moshi/server.py
+++ b/moshi/moshi/server.py
@@ -170,6 +170,9 @@ async def handle_chat(self, request):
         self.lm_gen.text_prompt_tokens = self.text_tokenizer.encode(wrap_with_system_tags(request.query["text_prompt"])) if len(request.query["text_prompt"]) > 0 else None
         seed = int(request["seed"]) if "seed" in request.query else None
 
+        # Track received opus bytes for processing
+        opus_bytes_queue = []
+
         async def recv_loop():
             nonlocal close
             try:
@@ -194,13 +197,17 @@ async def recv_loop():
                     kind = message[0]
                     if kind == 1:  # audio
                         payload = message[1:]
-                        opus_reader.append_bytes(payload)
+                        # Only queue the raw opus payload here; opus_loop decodes it via opus_reader.append_bytes
+                        opus_bytes_queue.append(payload)
                     else:
                         clog.log("warning", f"unknown message kind {kind}")
             finally:
                 close = True
                 clog.log("info", "connection closed")
 
+        # Track encoded opus bytes for sending
+        opus_bytes_out_queue = []
+
         async def opus_loop():
             all_pcm_data = None
 
@@ -208,47 +215,60 @@ async def opus_loop():
                 if close:
                     return
                 await asyncio.sleep(0.001)
-                pcm = opus_reader.read_pcm()
-                if pcm.shape[-1] == 0:
-                    continue
-                if all_pcm_data is None:
-                    all_pcm_data = pcm
-                else:
-                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
-                while all_pcm_data.shape[-1] >= self.frame_size:
-                    be = time.time()
-                    chunk = all_pcm_data[: self.frame_size]
-                    all_pcm_data = all_pcm_data[self.frame_size:]
-                    chunk = torch.from_numpy(chunk)
-                    chunk = chunk.to(device=self.device)[None, None]
-                    codes = self.mimi.encode(chunk)
-                    _ = self.other_mimi.encode(chunk)
-                    for c in range(codes.shape[-1]):
-                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
-                        if tokens is None:
-                            continue
-                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
-                        main_pcm = self.mimi.decode(tokens[:, 1:9])
-                        _ = self.other_mimi.decode(tokens[:, 1:9])
-                        main_pcm = main_pcm.cpu()
-                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
-                        text_token = tokens[0, 0, 0].item()
-                        if text_token not in (0, 3):
-                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
-                            _text = _text.replace("▁", " ")
-                            msg = b"\x02" + bytes(_text, encoding="utf8")
-                            await ws.send_bytes(msg)
-                        else:
-                            text_token_map = ['EPAD', 'BOS', 'EOS', 'PAD']
+
+                # Process any queued opus bytes
+                while opus_bytes_queue:
+                    payload = opus_bytes_queue.pop(0)
+                    # Kyutai sphn API: append_bytes returns decoded PCM immediately
+                    pcm = opus_reader.append_bytes(payload)
+                    if pcm.shape[-1] == 0:
+                        continue
+                    if all_pcm_data is None:
+                        all_pcm_data = pcm
+                    else:
+                        all_pcm_data = np.concatenate((all_pcm_data, pcm))
+
+                # Process accumulated PCM data
+                if all_pcm_data is not None:
+                    while all_pcm_data.shape[-1] >= self.frame_size:
+                        be = time.time()
+                        chunk = all_pcm_data[: self.frame_size]
+                        all_pcm_data = all_pcm_data[self.frame_size:]
+                        chunk = torch.from_numpy(chunk)
+                        chunk = chunk.to(device=self.device)[None, None]
+                        codes = self.mimi.encode(chunk)
+                        _ = self.other_mimi.encode(chunk)
+                        for c in range(codes.shape[-1]):
+                            tokens = self.lm_gen.step(codes[:, :, c: c + 1])
+                            if tokens is None:
+                                continue
+                            assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
+                            main_pcm = self.mimi.decode(tokens[:, 1:9])
+                            _ = self.other_mimi.decode(tokens[:, 1:9])
+                            main_pcm = main_pcm.cpu()
+                            # Kyutai sphn API: append_pcm returns encoded opus bytes immediately
+                            opus_bytes = opus_writer.append_pcm(main_pcm[0, 0].numpy())
+                            if len(opus_bytes) > 0:
+                                opus_bytes_out_queue.append(opus_bytes)
+                            text_token = tokens[0, 0, 0].item()
+                            if text_token not in (0, 3):
+                                _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
+                                _text = _text.replace("▁", " ")
+                                msg = b"\x02" + bytes(_text, encoding="utf8")
+                                await ws.send_bytes(msg)
+                            else:
+                                text_token_map = ['EPAD', 'BOS', 'EOS', 'PAD']
 
         async def send_loop():
             while True:
                 if close:
                     return
                 await asyncio.sleep(0.001)
-                msg = opus_writer.read_bytes()
-                if len(msg) > 0:
-                    await ws.send_bytes(b"\x01" + msg)
+                # Send any queued opus bytes
+                while opus_bytes_out_queue:
+                    msg = opus_bytes_out_queue.pop(0)
+                    if len(msg) > 0:
+                        await ws.send_bytes(b"\x01" + msg)
 
         clog.log("info", "accepted connection")
         if len(request.query["text_prompt"]) > 0: