From cd1489cce70e81197f24035c51f851a17ef5ee5a Mon Sep 17 00:00:00 2001
From: Jon Craton
Date: Thu, 26 Sep 2024 14:18:48 -0400
Subject: [PATCH] Support streaming on internal API

---
 languagemodels/inference.py | 89 +++++++++++++++++++++++++++----------
 1 file changed, 65 insertions(+), 24 deletions(-)

diff --git a/languagemodels/inference.py b/languagemodels/inference.py
index e9696af..d6643b3 100644
--- a/languagemodels/inference.py
+++ b/languagemodels/inference.py
@@ -112,6 +112,18 @@ def chat_oa(engine, prompt, max_tokens=200, temperature=0):
         raise InferenceException(f"OpenAI error: {resp}")
 
 
+def stream_results(results, tokenizer):
+    """Map a token iterator to a substring iterator"""
+    tokens = []
+    last_len = 0
+
+    for result in results:
+        tokens.append(result.token_id)
+        text = tokenizer.decode(tokens)
+        yield text[last_len:]
+        last_len = len(text)
+
+
 def generate(
     instructions: List[str],
     max_tokens: int = 200,
@@ -121,6 +133,7 @@ def generate(
     prefix: str = "",
     suppress: List[str] = [],
     model: str = "instruct",
+    stream: bool = False,
 ):
     """Generates completions for a prompt
 
@@ -129,6 +142,9 @@ def generate(
 
     >>> generate(["What is the capital of France?"])
     ['...Paris...']
+
+    >>> list(generate(["What is the capital of France?"], stream=True))
+    ['...Paris...']
     """
     if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
         "LANGUAGEMODELS_TS_SERVER"
@@ -157,31 +173,56 @@ def generate(
     outputs_ids = []
     if hasattr(model, "translate_batch"):
         prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
-        results = model.translate_batch(
-            prompts_tok,
-            target_prefix=[prefix] * len(prompts),
-            repetition_penalty=repetition_penalty,
-            max_decoding_length=max_tokens,
-            sampling_temperature=temperature,
-            sampling_topk=topk,
-            suppress_sequences=suppress,
-            beam_size=1,
-        )
-        outputs_tokens = [r.hypotheses[0] for r in results]
-        for output in outputs_tokens:
-            outputs_ids.append([tokenizer.token_to_id(t) for t in output])
+        if stream:
+            results = model.generate_tokens(
+                prompts_tok[0],
+                target_prefix=prefix,
+                repetition_penalty=repetition_penalty,
+                max_decoding_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+            )
+
+            return stream_results(results, tokenizer)
+        else:
+            results = model.translate_batch(
+                prompts_tok,
+                target_prefix=[prefix] * len(prompts),
+                repetition_penalty=repetition_penalty,
+                max_decoding_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+                beam_size=1,
+            )
+            outputs_tokens = [r.hypotheses[0] for r in results]
+            for output in outputs_tokens:
+                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
     else:
-        results = model.generate_batch(
-            prompts_tok,
-            repetition_penalty=repetition_penalty,
-            max_length=max_tokens,
-            sampling_temperature=temperature,
-            sampling_topk=topk,
-            suppress_sequences=suppress,
-            beam_size=1,
-            include_prompt_in_result=False,
-        )
-        outputs_ids = [r.sequences_ids[0] for r in results]
+        if stream:
+            results = model.generate_tokens(
+                prompts_tok,
+                repetition_penalty=repetition_penalty,
+                max_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+            )
+
+            return stream_results(results, tokenizer)
+        else:
+            results = model.generate_batch(
+                prompts_tok,
+                repetition_penalty=repetition_penalty,
+                max_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+                beam_size=1,
+                include_prompt_in_result=False,
+            )
+            outputs_ids = [r.sequences_ids[0] for r in results]
 
     model_info["requests"] = model_info.get("requests", 0) + len(prompts)
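
Usage note: a minimal sketch (not part of the patch above) of how the new stream flag might be consumed, assuming generate() can be imported from languagemodels/inference.py as shown; the prompt text and the chunk-printing loop are illustrative only.

    from languagemodels.inference import generate

    # With stream=True, generate() returns the substring generator produced by
    # stream_results(), yielding newly decoded text as tokens arrive.
    for chunk in generate(["What is the capital of France?"], stream=True):
        print(chunk, end="", flush=True)
    print()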