moshi/moshi/server.py (92 changes: 56 additions & 36 deletions)
@@ -170,6 +170,9 @@ async def handle_chat(self, request):
         self.lm_gen.text_prompt_tokens = self.text_tokenizer.encode(wrap_with_system_tags(request.query["text_prompt"])) if len(request.query["text_prompt"]) > 0 else None
         seed = int(request.query["seed"]) if "seed" in request.query else None

+        # Track received opus bytes for processing
+        opus_bytes_queue = []
+
         async def recv_loop():
             nonlocal close
             try:
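This hunk sets up a plain-list handoff: recv_loop appends raw opus payloads and opus_loop drains them. Because both coroutines run on the same asyncio event loop, a bare list with pop(0) needs no locking; an asyncio.Queue would avoid the polling sleep, at the cost of a larger change. A minimal self-contained sketch of the pattern (the names producer/consumer are illustrative, not from the patch):

```python
import asyncio

async def main():
    pending: list[bytes] = []   # same role as opus_bytes_queue
    done = False

    async def producer():       # stands in for recv_loop
        nonlocal done
        for i in range(3):
            pending.append(f"payload-{i}".encode())
            await asyncio.sleep(0.01)
        done = True

    async def consumer():       # stands in for opus_loop
        while not (done and not pending):
            await asyncio.sleep(0.001)  # same 1 ms polling cadence as the patch
            while pending:              # drain everything queued so far
                print("processing", pending.pop(0))

    await asyncio.gather(producer(), consumer())

asyncio.run(main())
```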
@@ -194,61 +197,78 @@ async def recv_loop():
                     kind = message[0]
                     if kind == 1: # audio
                         payload = message[1:]
-                        opus_reader.append_bytes(payload)
+                        # Kyutai sphn API: append_bytes returns decoded PCM immediately
+                        opus_bytes_queue.append(payload)
                     else:
                         clog.log("warning", f"unknown message kind {kind}")
             finally:
                 close = True
                 clog.log("info", "connection closed")

+        # Track encoded opus bytes for sending
+        opus_bytes_out_queue = []
+
         async def opus_loop():
             all_pcm_data = None

             while True:
                 if close:
                     return
                 await asyncio.sleep(0.001)
-                pcm = opus_reader.read_pcm()
-                if pcm.shape[-1] == 0:
-                    continue
-                if all_pcm_data is None:
-                    all_pcm_data = pcm
-                else:
-                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
-                while all_pcm_data.shape[-1] >= self.frame_size:
-                    be = time.time()
-                    chunk = all_pcm_data[: self.frame_size]
-                    all_pcm_data = all_pcm_data[self.frame_size:]
-                    chunk = torch.from_numpy(chunk)
-                    chunk = chunk.to(device=self.device)[None, None]
-                    codes = self.mimi.encode(chunk)
-                    _ = self.other_mimi.encode(chunk)
-                    for c in range(codes.shape[-1]):
-                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
-                        if tokens is None:
-                            continue
-                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
-                        main_pcm = self.mimi.decode(tokens[:, 1:9])
-                        _ = self.other_mimi.decode(tokens[:, 1:9])
-                        main_pcm = main_pcm.cpu()
-                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
-                        text_token = tokens[0, 0, 0].item()
-                        if text_token not in (0, 3):
-                            _text = self.text_tokenizer.id_to_piece(text_token) # type: ignore
-                            _text = _text.replace("▁", " ")
-                            msg = b"\x02" + bytes(_text, encoding="utf8")
-                            await ws.send_bytes(msg)
-                        else:
-                            text_token_map = ['EPAD', 'BOS', 'EOS', 'PAD']
+
+                # Process any queued opus bytes
+                while opus_bytes_queue:
+                    payload = opus_bytes_queue.pop(0)
+                    # Kyutai sphn API: append_bytes returns decoded PCM immediately
+                    pcm = opus_reader.append_bytes(payload)
+                    if pcm.shape[-1] == 0:
+                        continue
+                    if all_pcm_data is None:
+                        all_pcm_data = pcm
+                    else:
+                        all_pcm_data = np.concatenate((all_pcm_data, pcm))
+
+                # Process accumulated PCM data
+                if all_pcm_data is not None:
+                    while all_pcm_data.shape[-1] >= self.frame_size:
+                        be = time.time()
+                        chunk = all_pcm_data[: self.frame_size]
+                        all_pcm_data = all_pcm_data[self.frame_size:]
+                        chunk = torch.from_numpy(chunk)
+                        chunk = chunk.to(device=self.device)[None, None]
+                        codes = self.mimi.encode(chunk)
+                        _ = self.other_mimi.encode(chunk)
+                        for c in range(codes.shape[-1]):
+                            tokens = self.lm_gen.step(codes[:, :, c: c + 1])
+                            if tokens is None:
+                                continue
+                            assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
+                            main_pcm = self.mimi.decode(tokens[:, 1:9])
+                            _ = self.other_mimi.decode(tokens[:, 1:9])
+                            main_pcm = main_pcm.cpu()
+                            # Kyutai sphn API: append_pcm returns encoded opus bytes immediately
+                            opus_bytes = opus_writer.append_pcm(main_pcm[0, 0].numpy())
+                            if len(opus_bytes) > 0:
+                                opus_bytes_out_queue.append(opus_bytes)
+                            text_token = tokens[0, 0, 0].item()
+                            if text_token not in (0, 3):
+                                _text = self.text_tokenizer.id_to_piece(text_token) # type: ignore
+                                _text = _text.replace("▁", " ")
+                                msg = b"\x02" + bytes(_text, encoding="utf8")
+                                await ws.send_bytes(msg)
+                            else:
+                                text_token_map = ['EPAD', 'BOS', 'EOS', 'PAD']
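The accumulation logic itself is unchanged: append_bytes returns PCM in whatever sizes the opus packets decode to, and the loop re-slices that into fixed self.frame_size windows before self.mimi.encode. A standalone sketch of just the slicing, assuming a frame of 1920 samples (24 000 Hz at 12.5 frames per second; the real code reads self.frame_size):

```python
import numpy as np

frame_size = 1920  # assumed Mimi frame: 24_000 Hz / 12.5 frames per second
buf = None
for n in (500, 1500, 2500):  # decoder outputs arrive in arbitrary sizes
    pcm = np.zeros(n, dtype=np.float32)
    buf = pcm if buf is None else np.concatenate((buf, pcm))
    while buf.shape[-1] >= frame_size:
        chunk, buf = buf[:frame_size], buf[frame_size:]
        print("frame ready:", chunk.shape)  # fed to self.mimi.encode in the patch
```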

         async def send_loop():
             while True:
                 if close:
                     return
                 await asyncio.sleep(0.001)
-                msg = opus_writer.read_bytes()
-                if len(msg) > 0:
-                    await ws.send_bytes(b"\x01" + msg)
+                # Send any queued opus bytes
+                while opus_bytes_out_queue:
+                    msg = opus_bytes_out_queue.pop(0)
+                    if len(msg) > 0:
+                        await ws.send_bytes(b"\x01" + msg)

clog.log("info", "accepted connection")
if len(request.query["text_prompt"]) > 0: