
Support streaming on internal API
jncraton committed Sep 26, 2024
1 parent efd59d7 commit cd1489c
Showing 1 changed file with 65 additions and 24 deletions.
languagemodels/inference.py (89 changes: 65 additions & 24 deletions)
@@ -112,6 +112,18 @@ def chat_oa(engine, prompt, max_tokens=200, temperature=0):
         raise InferenceException(f"OpenAI error: {resp}")
 
 
+def stream_results(results, tokenizer):
+    """Map a token iterator to a substring iterator"""
+    tokens = []
+    last_len = 0
+
+    for result in results:
+        tokens.append(result.token_id)
+        text = tokenizer.decode(tokens)
+        yield text[last_len:]
+        last_len = len(text)
+
+
 def generate(
     instructions: List[str],
     max_tokens: int = 200,
@@ -121,6 +133,7 @@ def generate(
     prefix: str = "",
     suppress: List[str] = [],
     model: str = "instruct",
+    stream: bool = False,
 ):
     """Generates completions for a prompt
@@ -129,6 +142,9 @@ def generate(
     >>> generate(["What is the capital of France?"])
     ['...Paris...']
+
+    >>> list(generate(["What is the capital of France?"], stream=True))
+    ['...Paris...']
     """
     if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
         "LANGUAGEMODELS_TS_SERVER"
@@ -157,31 +173,56 @@ def generate(
     outputs_ids = []
     if hasattr(model, "translate_batch"):
         prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
-        results = model.translate_batch(
-            prompts_tok,
-            target_prefix=[prefix] * len(prompts),
-            repetition_penalty=repetition_penalty,
-            max_decoding_length=max_tokens,
-            sampling_temperature=temperature,
-            sampling_topk=topk,
-            suppress_sequences=suppress,
-            beam_size=1,
-        )
-        outputs_tokens = [r.hypotheses[0] for r in results]
-        for output in outputs_tokens:
-            outputs_ids.append([tokenizer.token_to_id(t) for t in output])
+        if stream:
+            results = model.generate_tokens(
+                prompts_tok[0],
+                target_prefix=prefix,
+                repetition_penalty=repetition_penalty,
+                max_decoding_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+            )
+
+            return stream_results(results, tokenizer)
+        else:
+            results = model.translate_batch(
+                prompts_tok,
+                target_prefix=[prefix] * len(prompts),
+                repetition_penalty=repetition_penalty,
+                max_decoding_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+                beam_size=1,
+            )
+            outputs_tokens = [r.hypotheses[0] for r in results]
+            for output in outputs_tokens:
+                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
     else:
-        results = model.generate_batch(
-            prompts_tok,
-            repetition_penalty=repetition_penalty,
-            max_length=max_tokens,
-            sampling_temperature=temperature,
-            sampling_topk=topk,
-            suppress_sequences=suppress,
-            beam_size=1,
-            include_prompt_in_result=False,
-        )
-        outputs_ids = [r.sequences_ids[0] for r in results]
+        if stream:
+            results = model.generate_tokens(
+                prompts_tok,
+                repetition_penalty=repetition_penalty,
+                max_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+            )
+
+            return stream_results(results, tokenizer)
+        else:
+            results = model.generate_batch(
+                prompts_tok,
+                repetition_penalty=repetition_penalty,
+                max_length=max_tokens,
+                sampling_temperature=temperature,
+                sampling_topk=topk,
+                suppress_sequences=suppress,
+                beam_size=1,
+                include_prompt_in_result=False,
+            )
+            outputs_ids = [r.sequences_ids[0] for r in results]
 
     model_info["requests"] = model_info.get("requests", 0) + len(prompts)

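The new stream_results helper re-decodes the accumulated token list on every step and yields only the newly produced suffix, which keeps the emitted text correct even when several token ids combine into a single piece of text. A minimal, self-contained sketch of that substring-delta technique (ToyTokenizer and ToyResult are hypothetical stand-ins for the real tokenizer and ctranslate2 step results, not part of this commit):

# Toy sketch illustrating the substring-delta approach used by stream_results.
class ToyTokenizer:
    # Hypothetical vocabulary; the real code uses the model's own tokenizer.
    vocab = {0: "Par", 1: "is", 2: " is the", 3: " capital."}

    def decode(self, token_ids):
        return "".join(self.vocab[t] for t in token_ids)


class ToyResult:
    # Stand-in for a generation step result exposing .token_id
    def __init__(self, token_id):
        self.token_id = token_id


def stream_results(results, tokenizer):
    """Map a token iterator to a substring iterator"""
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        yield text[last_len:]  # emit only the text added by this token
        last_len = len(text)


print(list(stream_results((ToyResult(i) for i in range(4)), ToyTokenizer())))
# ['Par', 'is', ' is the', ' capital.']

Joining the yielded chunks reproduces the full decoded string, so the streamed output matches what the batch path would have returned.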

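For reference, a rough sketch of how a caller might consume the new streaming mode, following the doctest added above; the printing loop is illustrative only and assumes the default "instruct" model can be loaded locally:

# Hypothetical caller (not part of the commit): iterate over the substring
# generator returned when stream=True and print text as it arrives.
from languagemodels.inference import generate

for chunk in generate(["What is the capital of France?"], stream=True):
    print(chunk, end="", flush=True)
print()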