diff --git a/src/deepsparse/v2/operators/engine_operator.py b/src/deepsparse/v2/operators/engine_operator.py
index b7d920a686..c2fc562c63 100644
--- a/src/deepsparse/v2/operators/engine_operator.py
+++ b/src/deepsparse/v2/operators/engine_operator.py
@@ -87,6 +87,9 @@ def __init__(
         self._engine_args = engine_args
         self._engine_type = engine_type
 
+        if not engine_kwargs:
+            engine_kwargs = {}
+
         self.engine = self.create_engine(**engine_kwargs)
 
     @property
diff --git a/src/deepsparse/v2/text_generation/__init__.py b/src/deepsparse/v2/text_generation/__init__.py
index 37ac88d02f..21cd7e2acd 100644
--- a/src/deepsparse/v2/text_generation/__init__.py
+++ b/src/deepsparse/v2/text_generation/__init__.py
@@ -13,12 +13,19 @@
 # limitations under the License.
 # flake8: noqa
 from .autoregressive_preprocess_operator import *
+from .compile_generated_tokens import *
+from .compile_generations import *
 from .compile_logits import *
+from .generate_new_token import *
 from .kv_cache_operator import *
 from .multi_engine_prefill_operator import *
 from .nl_engine_operator import *
 from .prep_for_prefill import *
 from .process_inputs import *
+from .process_outputs import *
+from .token_generator import *  # isort:skip
+from .prep_for_generation import *  # isort:skip
+
 from .pipeline import *  # isort:skip
diff --git a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py
index cfe7cb531b..6e97412e43 100644
--- a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py
+++ b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py
@@ -36,7 +36,6 @@ def __init__(self, sequence_length: int, prompt_sequence_length: int):
         """
         self.sequence_length = sequence_length
         self.prompt_sequence_length = prompt_sequence_length
-        self.set_capacity = False
 
         _LOGGER.warn(
             "This operator requires the PipelineState to be set-up with the "
@@ -51,16 +50,19 @@ def can_operate(self, inp: Any) -> bool:
         tokens = inp.get("tokens")
         kv_cache = inp.get("kv_cache")
 
+        if inp.get("in_generation"):
+            return True
+
         remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
-        if remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length:
+        can_process = (
+            remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
+        )
+        if can_process and inp.get("in_generation") is None:
             return True
         return False
 
     def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):
-
-        if not self.set_capacity:
-            self.set_capacity = True
-            kv_cache.set_capacity(self.sequence_length - 1)
+        kv_cache.set_capacity(self.sequence_length - 1)
 
         num_total_processed_tokens = kv_cache.total_num_processed_tokens
         new_token = tokens[num_total_processed_tokens]
@@ -88,13 +90,9 @@ def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwarg
 
         engine_inputs = [engine_inputs_map[name] for name in engine_input_names]
 
-        onnx_input_names_no_cache = pipeline_state.current_state.get(
-            "onnx_input_names_no_cache"
-        )
-        engine_inputs = [engine_inputs_map[name] for name in onnx_input_names_no_cache]
-
         return {
             "engine_inputs": engine_inputs,
             "kv_cache": kv_cache,
             "tokens": tokens,
+            "in_generation": kwargs.get("in_generation"),
         }
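Note: the `in_generation` flag threaded through these operators is tri-state: `None` while the prompt is still being processed, `True` inside the generation loop, and `False` once a finish reason has been recorded. A minimal sketch of the gate `can_operate` now implements, with a plain dict standing in for the router's payload:

```python
# Sketch of the tri-state `in_generation` gate mirroring can_operate above;
# `inp` is a stand-in payload, not the real router input object.
def routes_to_autoregressive(inp: dict, prompt_sequence_length: int) -> bool:
    if inp.get("in_generation"):  # generation loop: always the single-token engine
        return True
    remaining = len(inp["tokens"]) - inp["kv_cache"].total_num_processed_tokens
    # prompt tail: fewer tokens remain than the multi-token engine consumes per pass
    can_process = 0 < remaining < prompt_sequence_length
    return can_process and inp.get("in_generation") is None
```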
diff --git a/src/deepsparse/v2/text_generation/compile_generated_tokens.py b/src/deepsparse/v2/text_generation/compile_generated_tokens.py
new file mode 100644
index 0000000000..c87436ab3a
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/compile_generated_tokens.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import InferenceState
+
+
+__all__ = ["CompileGeneratedTokens"]
+
+
+class CompileGeneratedTokens(Operator):
+    def run(
+        self,
+        new_token,
+        logits,
+        finish_reason,
+        kv_cache,
+        tokens,
+        inference_state: InferenceState,
+        **kwargs,
+    ):
+        in_generation = True
+
+        generated_tokens = inference_state.current_state.get("generated_tokens")
+        generated_logits = inference_state.current_state.get("generated_logits")
+        finished_reason = inference_state.current_state.get("finished_reason")
+
+        generated_tokens.append(new_token)
+        generated_logits.append(logits)
+        finished_reason.append(finish_reason)
+
+        if finish_reason is not None:
+            in_generation = False
+
+        state_update = {  # TODO: check if necessary
+            "finished_reason": finished_reason,
+            "generated_tokens": generated_tokens,
+            "generated_logits": generated_logits,
+        }
+
+        output = {
+            "tokens": tokens,
+            "kv_cache": kv_cache,
+            "in_generation": in_generation,
+        }
+        return output, state_update
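`CompileGeneratedTokens.run` returns an `(output, state_update)` pair rather than a bare output dict. A hedged, self-contained sketch of that contract; the merge step is an assumption about how the pipeline folds `state_update` into `InferenceState`, not something this diff confirms:

```python
# Hedged sketch: consuming the (output, state_update) return contract used by
# CompileGeneratedTokens. The dict-based state and update() merge are stand-ins
# for the real InferenceState plumbing.
state = {"generated_tokens": [], "generated_logits": [], "finished_reason": []}


def run(new_token, logits, finish_reason):
    state_update = {
        "generated_tokens": state["generated_tokens"] + [new_token],
        "generated_logits": state["generated_logits"] + [logits],
        "finished_reason": state["finished_reason"] + [finish_reason],
    }
    # generation continues only while no finish reason has been recorded
    output = {"in_generation": finish_reason is None}
    return output, state_update


output, state_update = run(new_token=42, logits=[0.1], finish_reason=None)
state.update(state_update)  # assumed merge semantics
assert output["in_generation"] is True
```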
diff --git a/src/deepsparse/v2/text_generation/compile_generations.py b/src/deepsparse/v2/text_generation/compile_generations.py
new file mode 100644
index 0000000000..ed8297ac01
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/compile_generations.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import numpy
+from pydantic import BaseModel, Field
+
+from deepsparse.transformers.pipelines.text_generation import FinishReason
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import InferenceState
+
+
+__all__ = ["CompileGenerations", "CompileGenerationsOutput"]
+
+
+class CompileGenerationsOutput(BaseModel):
+    generated_tokens: Any = Field(description="generated_tokens")
+    generated_logits: Any = Field(description="generated_logits")
+    finished_reason: Any = Field(description="finished_reason")
+
+
+class CompileGenerations(Operator):
+    output_schema = CompileGenerationsOutput
+
+    def can_operate(self, inp: Any):
+        if inp.get("in_generation") is False:
+            return True
+        return False
+
+    def run(self, inference_state: InferenceState, **kwargs):
+        generated_tokens = inference_state.current_state.get("generated_tokens")
+        generated_logits = inference_state.current_state.get("generated_logits")
+        finished_reason = inference_state.current_state.get("finished_reason")
+
+        if len(finished_reason) == 0:
+            finished_reason.append(FinishReason.LENGTH)
+
+        generated_tokens = numpy.array([generated_tokens])
+        generated_logits = numpy.concatenate(generated_logits, axis=1)
+        return {
+            "generated_tokens": generated_tokens,
+            "generated_logits": generated_logits,
+            "finished_reason": finished_reason,
+        }
diff --git a/src/deepsparse/v2/text_generation/compile_logits.py b/src/deepsparse/v2/text_generation/compile_logits.py
index 55c87d791d..21bd50e03e 100644
--- a/src/deepsparse/v2/text_generation/compile_logits.py
+++ b/src/deepsparse/v2/text_generation/compile_logits.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
 
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.utils import InferenceState
@@ -27,6 +28,11 @@ class CompilePromptLogits(Operator):
     take prompt logits from each iteration run and update the inference state.
     """
 
+    def can_operate(self, inp: Any):
+        if inp.get("in_generation") is None:
+            return True
+        return False
+
     def run(self, logits, inference_state: InferenceState, **kwargs):
 
         logit_type = "prompt_logits"
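For reference, the `axis=1` concatenation in `CompileGenerations.run` relies on each generation step contributing a `(1, 1, vocab_size)` logits slice, so `n` steps stitch into a single `(1, n, vocab_size)` array. A quick runnable shape check (`vocab_size` here is an arbitrary example value):

```python
import numpy

# Each generation step appends one (1, 1, vocab_size) slice to
# generated_logits; concatenating on axis=1 stacks them token-wise.
vocab_size = 32
step_logits = [numpy.zeros((1, 1, vocab_size)) for _ in range(5)]
stacked = numpy.concatenate(step_logits, axis=1)
assert stacked.shape == (1, 5, vocab_size)
```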
diff --git a/src/deepsparse/v2/text_generation/generate_new_token.py b/src/deepsparse/v2/text_generation/generate_new_token.py
new file mode 100644
index 0000000000..33ab546e39
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/generate_new_token.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Sequence, Union
+
+import transformers
+
+from deepsparse.transformers.pipelines.text_generation import FinishReason
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import InferenceState
+
+
+__all__ = ["GenerateNewTokenOperator"]
+
+
+class GenerateNewTokenOperator(Operator):
+    def __init__(
+        self, tokenizer: transformers.PreTrainedTokenizerBase, force_max_tokens: bool
+    ):
+        self.force_max_tokens = force_max_tokens
+        self.tokenizer = tokenizer
+
+    def can_operate(self, inp: Any):
+        if inp.get("in_generation"):
+            return True
+        return False
+
+    def run(self, logits, kv_cache, inference_state: InferenceState, **kwargs):
+        token_generator = inference_state.current_state.get("token_generator")
+        token = token_generator.generate(logits=logits[0, -1, :])
+        finish_reason = None
+
+        callback = inference_state.current_state.get("callback")
+        stop = inference_state.current_state.get("stop")
+
+        if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
+            finish_reason = FinishReason.STOP
+
+        if self._stop_token_generated(token, stop_tokens=stop):
+            print(
+                "Stop token %s generated. Stopping generation."
+                % self.tokenizer.decode(token)
+            )
+            finish_reason = FinishReason.STOP
+
+        if callback is not None and callback(token) is False:
+            print(
+                "callback %s returned False, stopping generation."
+                % callback.__qualname__
+            )
+            finish_reason = FinishReason.CALLBACK
+
+        max_tokens = inference_state.current_state.get("max_tokens")
+        if len(inference_state.current_state.get("generated_tokens")) + 1 == max_tokens:
+            finish_reason = inference_state.current_state.get("length_finish_reason")
+
+        state_update = {
+            "token_generator": token_generator,
+        }
+
+        new_generation = {
+            "logits": logits,
+            "new_token": token,
+            "finish_reason": finish_reason,
+        }
+        output = {"tokens": token_generator.tokens, "kv_cache": kv_cache}
+        output.update(new_generation)
+        return output, state_update
+
+    def _stop_token_generated(
+        self, token, stop_tokens: Union[None, str, Sequence[str]]
+    ) -> bool:
+        if stop_tokens is None:
+            return False
+
+        decoded_token = self.tokenizer.decode(token)
+        decoded_token = (
+            decoded_token if decoded_token.isspace() else decoded_token.strip()
+        )
+        return decoded_token in stop_tokens
diff --git a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
index 41ee830a8a..9a885c2355 100644
--- a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
+++ b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
@@ -97,6 +97,7 @@ def _case_positions(self, num_total_processed_tokens: int):
         )
 
     def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):
+        kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length)
 
         onnx_input_names_no_cache = pipeline_state.current_state.get(
             "onnx_input_names_no_cache"
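The whitespace special-case in `_stop_token_generated` is easy to misread: decoded tokens are stripped *unless* the whole token is whitespace, so padded variants of a stop word still match while `"\n"` remains usable as a stop token. A standalone illustration, with a plain string standing in for `tokenizer.decode(token)`:

```python
# Mirror of the normalization in _stop_token_generated above.
def normalize(decoded_token: str) -> str:
    return decoded_token if decoded_token.isspace() else decoded_token.strip()


assert normalize(" stop ") == "stop"  # padded stop words still match "stop"
assert normalize("\n") == "\n"        # pure-whitespace stop tokens survive intact
```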
diff --git a/src/deepsparse/v2/text_generation/nl_engine_operator.py b/src/deepsparse/v2/text_generation/nl_engine_operator.py
index 6c1ad1966e..0bd9098a40 100644
--- a/src/deepsparse/v2/text_generation/nl_engine_operator.py
+++ b/src/deepsparse/v2/text_generation/nl_engine_operator.py
@@ -36,6 +36,7 @@ class NlEngineInput(BaseModel):
     engine_inputs: List = Field(description="engine inputs")
     kv_cache: Any = Field(description="kv_cache object")
     tokens: List = Field(description="tokens")
+    in_generation: bool = Field(description="in_generation", default=None)
 
 
 class NLEngineOperator(EngineOperator):
@@ -119,7 +120,12 @@ def run(self, inp: NlEngineInput, **kwargs) -> Any:
             kv_cache=kv_cache,
         )
 
-        output = {"logits": logits, "kv_cache": kv_cache, "tokens": inp.tokens}
+        output = {
+            "logits": logits,
+            "kv_cache": kv_cache,
+            "tokens": inp.tokens,
+            "in_generation": inp.in_generation,
+        }
         return output
 
     def _add_kv_cache_to_input(self, engine_input, kv_cache):
diff --git a/src/deepsparse/v2/text_generation/pipeline.py b/src/deepsparse/v2/text_generation/pipeline.py
index 9878aa0061..49826b8af7 100644
--- a/src/deepsparse/v2/text_generation/pipeline.py
+++ b/src/deepsparse/v2/text_generation/pipeline.py
@@ -15,18 +15,23 @@
 from typing import Dict
 
 from deepsparse.transformers.utils.helpers import process_generation_config
-from deepsparse.v2.operators import Operator
 from deepsparse.v2.pipeline import Pipeline
 from deepsparse.v2.routers import GraphRouter
 from deepsparse.v2.schedulers import OperatorScheduler
 from deepsparse.v2.text_generation import (
     AutoRegressiveOperatorPreprocess,
+    CompileGeneratedTokens,
+    CompileGenerations,
     CompilePromptLogits,
+    GenerateNewTokenOperator,
     KVCacheCreator,
     MultiEnginePrefill,
     NLEngineOperator,
     PrepareforPrefill,
+    PrepareGeneration,
     ProcessInputsTextGeneration,
+    ProcessOutputs,
+    TokenGeneratorOperator,
 )
 from deepsparse.v2.utils import PipelineState
@@ -109,17 +114,23 @@ def __init__(
             sequence_length=sequence_length,
         )
         compile_prompt_logits = CompilePromptLogits()
-        """
-        prep_for_single_engine = PrepareforSingleEngine(
-            prompt_sequence_length=prompt_sequence_length,
+
+        autoregressive_preprocess = AutoRegressiveOperatorPreprocess(
             sequence_length=sequence_length,
+            prompt_sequence_length=prompt_sequence_length,
         )
-        """
-        autoregressive_preprocess = AutoRegressiveOperatorPreprocess(
+        token_generator = TokenGeneratorOperator()
+        prep_for_generation = PrepareGeneration(
             sequence_length=sequence_length,
             prompt_sequence_length=prompt_sequence_length,
+            token_generator=token_generator,
+        )
+        generate_new_token = GenerateNewTokenOperator(
+            tokenizer=self.tokenizer, force_max_tokens=force_max_tokens
         )
-        final_step = FinalStep()
+        process_output = ProcessOutputs(tokenizer=self.tokenizer)
+        compile_generations = CompileGenerations()
+        compile_generated_tokens = CompileGeneratedTokens()
 
         ops = {
             "process_input": process_inputs,
@@ -130,7 +141,11 @@ def __init__(
             "multi_engine_prefill": multi_engine_prefill,
             "compile_logits": compile_prompt_logits,
             "autoregressive_preprocess": autoregressive_preprocess,
-            "final_step": final_step,
+            "prep_for_generation": prep_for_generation,
+            "generate_new_token": generate_new_token,
+            "process_outputs": process_output,
+            "compile_generations": compile_generations,
+            "compile_generated_tokens": compile_generated_tokens,
         }
 
         routes = {
@@ -140,12 +155,22 @@ def __init__(
             "multi_engine": "compile_logits",
             "compile_logits": [
                 "multi_engine_prefill",
+                "prep_for_generation",
                 "autoregressive_preprocess",
-                "final_step",
             ],
             "autoregressive_preprocess": "single_engine",
-            "single_engine": "compile_logits",
-            "final_step": "STOP",
+            "single_engine": [
+                "compile_logits",
+                "generate_new_token",
+            ],
+            "prep_for_generation": "autoregressive_preprocess",
+            "generate_new_token": "compile_generated_tokens",
+            "compile_generated_tokens": [
+                "autoregressive_preprocess",
+                "compile_generations",
+            ],
+            "compile_generations": "process_outputs",
+            "process_outputs": "STOP",
         }
 
         router = GraphRouter(
@@ -197,17 +222,3 @@ def setup_onnx_file_path(self, model_path, sequence_length) -> str:
             "See `tokenizer` and `config` arguments for details."
         )
         return onnx_path
-
-
-# NOTE: This is a dummy last step which will be removed. Used as a final step
-# for the current routes.
-class FinalStep(Operator):
-    def can_operate(self, *args, **kwargs):
-        return True
-
-    def run(self, *args, **kwargs):
-        import numpy
-
-        inference_state = kwargs.get("inference_state")
-        prompt_logits = inference_state.current_state.get("prompt_logits")
-        return numpy.concatenate(prompt_logits, axis=1)
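The new `routes` table encodes two cycles: a prefill loop (`compile_logits` back through the engines until the prompt is consumed) and a generation loop. A hedged sketch walking the happy path of the generation cycle; the real `GraphRouter` resolves the branch lists via each candidate's `can_operate`, which this hard-codes:

```python
# Happy-path walk of the generation cycle from the routes table above.
# Branch lists ("single_engine", "compile_generated_tokens") are collapsed to
# the edge taken while in_generation is True; GraphRouter actually selects the
# edge by calling can_operate on each candidate operator.
routes = {
    "prep_for_generation": "autoregressive_preprocess",
    "autoregressive_preprocess": "single_engine",
    "single_engine": "generate_new_token",
    "generate_new_token": "compile_generated_tokens",
    "compile_generated_tokens": "autoregressive_preprocess",  # until finish_reason
}

node, path = "prep_for_generation", []
for _ in range(6):
    path.append(node)
    node = routes[node]
print(" -> ".join(path))  # one full generation iteration, then the loop repeats
```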
diff --git a/src/deepsparse/v2/text_generation/prep_for_generation.py b/src/deepsparse/v2/text_generation/prep_for_generation.py
new file mode 100644
index 0000000000..544af43980
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/prep_for_generation.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import numpy
+
+from deepsparse.transformers.pipelines.text_generation import FinishReason
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.text_generation import TokenGeneratorOperator
+from deepsparse.v2.utils import InferenceState
+
+
+__all__ = ["PrepareGeneration"]
+
+
+class PrepareGeneration(Operator):
+    def __init__(
+        self,
+        token_generator: TokenGeneratorOperator,
+        prompt_sequence_length: int,
+        sequence_length: int,
+    ):
+        self.prompt_sequence_length = prompt_sequence_length
+        self.sequence_length = sequence_length
+        self.token_generator_creator = token_generator
+
+    def can_operate(self, inp: Any):
+        kv_cache = inp.get("kv_cache")
+        tokens = inp.get("tokens")
+
+        # Generation can only start once every prompt token has been processed:
+        # the token count and the kv_cache's processed-token count must be equal,
+        # meaning all prompt logits are accounted for and the kv_cache has been
+        # updated for the single-token engine.
+        if len(tokens) == kv_cache.total_num_processed_tokens:
+            return True
+        return False
+
+    @staticmethod
+    def set_generated_length(
+        max_length: int,
+        prompt_tokens_length: int,
+        sequence_length: int,
+        prompt_sequence_length: int,
+        max_new_tokens: int,
+        finish_reason_choices: "FinishReason",  # noqa
+    ):
+        """
+        Determine the length of the generated tokens. The hard cap on the total
+        number of tokens is the sequence length. If max_length is provided and is
+        less than the sequence length, it caps the total number of tokens
+        generated. If it is not provided, the max_new_tokens attribute is used
+        instead, likewise capped by the sequence length.
+
+        :param max_length: max_length attribute, provided as input during inference
+        :param prompt_tokens_length: the number of prompt tokens used as part of the
+            generated output
+        :param sequence_length: the sequence length used for the pipeline
+        :param prompt_sequence_length: the prompt sequence length used for the pipeline
+        :param max_new_tokens: the max_new_tokens attribute, which may be provided
+            as part of the input during inference
+        """
+        if max_length:
+            # if max_length provided, use that to cap total tokens generated
+            max_tokens = max_length
+            finish_reason = finish_reason_choices.LENGTH
+        else:
+            # if not provided, max tokens is based on max_new_tokens + prompt tokens
+            max_tokens = (
+                min(max_new_tokens, sequence_length - prompt_sequence_length)
+                + prompt_tokens_length
+            )
+            finish_reason = finish_reason_choices.MAX_NEW_TOKENS
+
+        # hard model/pipeline cap
+        return (
+            (sequence_length, finish_reason_choices.CAPACITY)
+            if sequence_length < max_tokens
+            else (max_tokens, finish_reason)
+        )
+
+    def run(
+        self, tokens: Any, kv_cache: Any, inference_state: InferenceState, **kwargs
+    ):
+        prompt_logits = inference_state.current_state.get("prompt_logits")
+        prompt_logits = numpy.concatenate(prompt_logits, axis=1)
+        # TODO: clean this up such that we don't have to keep writing
+        # current_state everywhere
+
+        generation_config = inference_state.current_state.get("generation_config")
+        include_prompt_logits = inference_state.current_state.get(
+            "include_prompt_logits"
+        )
+
+        token_generator_creator_output = self.token_generator_creator.run(
+            logits_shape=prompt_logits[0, -1, :].shape,
+            deterministic=not generation_config.do_sample,
+            sampling_temperature=generation_config.temperature,
+            tokens=tokens,
+            **inference_state.current_state,
+        )
+        token_generator = token_generator_creator_output.get("token_generator")
+        token_generator.generate(prompt_logits[0, -1, :])
+
+        max_tokens, length_finish_reason = PrepareGeneration.set_generated_length(
+            max_length=generation_config.max_length,
+            prompt_tokens_length=1,
+            max_new_tokens=generation_config.max_new_tokens,
+            sequence_length=self.sequence_length,
+            prompt_sequence_length=self.prompt_sequence_length,
+            finish_reason_choices=FinishReason,
+        )
+        state_update = {
+            "max_tokens": max_tokens,
+            "length_finish_reason": length_finish_reason,
+            "generated_tokens": [token_generator.tokens[-1]],
+            "generated_logits": [prompt_logits]
+            if include_prompt_logits
+            else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
+            "finished_reason": [],
+            "token_generator": token_generator,
+        }
+
+        output = {
+            "tokens": token_generator.tokens,
+            "kv_cache": kv_cache,
+            "in_generation": True,
+        }
+        return output, state_update
diff --git a/src/deepsparse/v2/text_generation/process_inputs.py b/src/deepsparse/v2/text_generation/process_inputs.py
index 528dcee0b7..e57e402983 100644
--- a/src/deepsparse/v2/text_generation/process_inputs.py
+++ b/src/deepsparse/v2/text_generation/process_inputs.py
@@ -28,7 +28,7 @@ class GenerationDefaults:
     num_return_sequences = 1
-    max_length = 1024
+    max_length = 100
     max_new_tokens = None
     output_scores = False
     top_k = 0
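Since `GenerationDefaults.max_length` drops from 1024 to 100 above, the default token cap now comes from the `max_length` branch of `set_generated_length`. A worked example using the classes and enum this diff introduces; the prompt and sequence-length values are illustrative:

```python
from deepsparse.transformers.pipelines.text_generation import FinishReason
from deepsparse.v2.text_generation import PrepareGeneration

# With the new default max_length=100 and a typical sequence_length=2048, the
# max_length branch wins and the hard sequence-length cap never triggers.
max_tokens, reason = PrepareGeneration.set_generated_length(
    max_length=100,
    prompt_tokens_length=10,
    sequence_length=2048,
    prompt_sequence_length=16,
    max_new_tokens=None,  # unused when max_length is set
    finish_reason_choices=FinishReason,
)
assert max_tokens == 100 and reason == FinishReason.LENGTH
```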
diff --git a/src/deepsparse/v2/text_generation/process_outputs.py b/src/deepsparse/v2/text_generation/process_outputs.py
new file mode 100644
index 0000000000..ca1cf78521
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/process_outputs.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+from typing import Optional
+
+import numpy
+
+from deepsparse.transformers.pipelines.text_generation import (
+    FinishReason,
+    GeneratedText,
+    TextGenerationOutput,
+)
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.text_generation.compile_generations import CompileGenerationsOutput
+from deepsparse.v2.utils import InferenceState
+
+
+class ProcessOutputs(Operator):
+    output_schema = TextGenerationOutput
+
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def _create_generated_text_output(
+        self,
+        sequence: str,
+        finish_reason: Optional[FinishReason] = None,
+        logits: Optional[numpy.ndarray] = None,
+    ):
+        if finish_reason:
+            return GeneratedText(
+                text=sequence,
+                score=logits,
+                finished=True,
+                finished_reason=finish_reason.value,
+            )
+        return GeneratedText(
+            text=sequence,
+            score=logits,
+            finished=False,
+        )
+
+    def run(
+        self, inp: CompileGenerationsOutput, inference_state: InferenceState, **kwargs
+    ):
+        generation_config = inference_state.current_state.get("generation_config")
+        generated_tokens = inp.generated_tokens
+        generated_logits = (
+            inp.generated_logits if generation_config.output_scores else None
+        )
+        finished_reason = inp.finished_reason
+        sequences = self.tokenizer.batch_decode(
+            generated_tokens, skip_special_tokens=True
+        )
+
+        finished_reason = [f for f in finished_reason if f]
+
+        if generated_logits is not None:
+            generations = list(
+                map(
+                    self._create_generated_text_output,
+                    sequences,
+                    finished_reason,
+                    generated_logits,
+                )
+            )
+        else:
+            generations = list(
+                map(self._create_generated_text_output, sequences, finished_reason)
+            )
+        outputs = dict(
+            created=datetime.datetime.now(),
+            prompts=inference_state.current_state.get("prompts"),
+            generations=generations,
+        )
+
+        return outputs
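`ProcessOutputs.run` relies on `map()` with multiple iterables to pair each decoded sequence with its finish reason (and optionally its logits) element-wise. A self-contained illustration of that pairing, with a plain dict standing in for `GeneratedText`:

```python
# map() with two iterables zips them positionally, which is how run()
# matches each decoded sequence to its own finish reason.
sequences = ["hello world", "foo bar"]
reasons = ["stop", "length"]


def make_generation(text, reason):
    # stand-in for _create_generated_text_output
    return {"text": text, "finished": True, "finished_reason": reason}


generations = list(map(make_generation, sequences, reasons))
assert generations[1]["finished_reason"] == "length"
```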
diff --git a/src/deepsparse/v2/text_generation/token_generator.py b/src/deepsparse/v2/text_generation/token_generator.py
new file mode 100644
index 0000000000..9148d71cc8
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/token_generator.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from deepsparse.transformers.utils.token_generator import TokenGenerator
+from deepsparse.v2.operators import Operator
+
+
+__all__ = ["TokenGeneratorOperator"]
+
+
+class TokenGeneratorOperator(Operator):
+    def run(self, logits_shape, deterministic, tokens, sampling_temperature, **kwargs):
+        token_generator = TokenGenerator(
+            logits_shape=logits_shape,
+            deterministic=deterministic,
+            tokens=tokens,
+            sampling_temperature=sampling_temperature,
+            **kwargs,
+        )
+        return {"token_generator": token_generator}
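A hedged usage sketch of `TokenGeneratorOperator`, mirroring how `PrepareGeneration.run` seeds the generator from the last prompt logit; the vocab size and token ids are illustrative stand-ins:

```python
import numpy

from deepsparse.v2.text_generation import TokenGeneratorOperator

vocab_size = 32
prompt_logits = numpy.random.rand(1, 4, vocab_size).astype(numpy.float32)

creator = TokenGeneratorOperator()
out = creator.run(
    logits_shape=prompt_logits[0, -1, :].shape,
    deterministic=True,  # greedy, as when generation_config.do_sample is False
    sampling_temperature=1.0,
    tokens=[101, 2023, 2003],  # assumed prompt token ids
)
token_generator = out["token_generator"]
first_token = token_generator.generate(prompt_logits[0, -1, :])
```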