Skip to content

Commit

Permalink
[Pipeline Refactor] Additional Operators, Route update and completed …
Browse files Browse the repository at this point in the history
…generation functionality (#1356)

* initial functionality and working example with image classification

* remove testing image

* rebase fixes

* initial functionality and working example with image classification

* text gen

* updates func

* prompt inference, initial functionality

* remove image; update state docstring

* Fix typo

* add todo for split/join

* remove context, clean-up args, remove prefill_preprocess_operator

* fix docstrings

* initial functionality and working example with image classification

* updates func

* prompt inference, initial functionality

* finish generation operators and update routes

* further breakdown operators

* add operators

* fix can_operate condition

* update can_operate to not rely on the inference_state

* rebase + update

* fix condition

* fix capacity setting again

* typo fixes
  • Loading branch information
dsikka authored and dbogunowicz committed Dec 18, 2023
1 parent d54ef26 commit 5ee36d6
Show file tree
Hide file tree
Showing 14 changed files with 529 additions and 38 deletions.
3 changes: 3 additions & 0 deletions src/deepsparse/v2/operators/engine_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def __init__(
self._engine_args = engine_args
self._engine_type = engine_type

if not engine_kwargs:
engine_kwargs = {}

self.engine = self.create_engine(**engine_kwargs)

@property
Expand Down
7 changes: 7 additions & 0 deletions src/deepsparse/v2/text_generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
# limitations under the License.
# flake8: noqa
from .autoregressive_preprocess_operator import *
from .compile_generated_tokens import *
from .compile_generations import *
from .compile_logits import *
from .generate_new_token import *
from .kv_cache_operator import *
from .multi_engine_prefill_operator import *
from .nl_engine_operator import *
from .prep_for_prefill import *
from .process_inputs import *
from .process_outputs import *


from .token_generator import * # isort:skip
from .prep_for_generation import * # isort:skip

from .pipeline import * # isort:skip
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def __init__(self, sequence_length: int, prompt_sequence_length: int):
"""
self.sequence_length = sequence_length
self.prompt_sequence_length = prompt_sequence_length
self.set_capacity = False

_LOGGER.warn(
"This operator requires the PipelineState to be set-up with the "
Expand All @@ -51,16 +50,19 @@ def can_operate(self, inp: Any) -> bool:
tokens = inp.get("tokens")
kv_cache = inp.get("kv_cache")

if inp.get("in_generation"):
return True

remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
if remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length:
can_process = (
remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
)
if can_process and inp.get("in_generation") is None:
return True
return False

def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):

if not self.set_capacity:
self.set_capacity = True
kv_cache.set_capacity(self.sequence_length - 1)
kv_cache.set_capacity(self.sequence_length - 1)

num_total_processed_tokens = kv_cache.total_num_processed_tokens
new_token = tokens[num_total_processed_tokens]
Expand Down Expand Up @@ -88,13 +90,9 @@ def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwarg

engine_inputs = [engine_inputs_map[name] for name in engine_input_names]

onnx_input_names_no_cache = pipeline_state.current_state.get(
"onnx_input_names_no_cache"
)
engine_inputs = [engine_inputs_map[name] for name in onnx_input_names_no_cache]

return {
"engine_inputs": engine_inputs,
"kv_cache": kv_cache,
"tokens": tokens,
"in_generation": kwargs.get("in_generation"),
}
56 changes: 56 additions & 0 deletions src/deepsparse/v2/text_generation/compile_generated_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState


__all__ = ["CompileGeneratedTokens"]


class CompileGeneratedTokens(Operator):
    """
    Append the latest generation step (token, logits, finish reason) to the
    running lists held in the inference state and report whether generation
    is still in progress.
    """

    def run(
        self,
        new_token,
        logits,
        finish_reason,
        kv_cache,
        tokens,
        inference_state: InferenceState,
        **kwargs,
    ):
        """
        :param new_token: token produced by the current generation step
        :param logits: logits associated with ``new_token``
        :param finish_reason: reason generation should stop, or None to continue
        :param kv_cache: kv cache object, passed through unchanged
        :param tokens: tokens, passed through unchanged
        :param inference_state: state holding the "generated_tokens",
            "generated_logits" and "finished_reason" lists
        :return: tuple of (operator output, inference state update)
        """
        token_history = inference_state.current_state.get("generated_tokens")
        logit_history = inference_state.current_state.get("generated_logits")
        reason_history = inference_state.current_state.get("finished_reason")

        # The lists fetched from the state are extended in place.
        token_history.append(new_token)
        logit_history.append(logits)
        reason_history.append(finish_reason)

        # TODO: check if necessary
        state_update = {
            "finished_reason": reason_history,
            "generated_tokens": token_history,
            "generated_logits": logit_history,
        }

        # Generation continues only while no finish reason has been reported.
        output = {
            "tokens": tokens,
            "kv_cache": kv_cache,
            "in_generation": finish_reason is None,
        }
        return output, state_update
55 changes: 55 additions & 0 deletions src/deepsparse/v2/text_generation/compile_generations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import numpy
from pydantic import BaseModel, Field

from deepsparse.transformers.pipelines.text_generation import FinishReason
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState


__all__ = ["CompileGenerations", "CompileGenerationsOutput"]


class CompileGenerationsOutput(BaseModel):
    """
    Schema with the three fields returned by ``CompileGenerations.run``.
    """

    generated_tokens: Any = Field(
        description="generated_tokens",
    )
    generated_logits: Any = Field(
        description="generated_logits",
    )
    finished_reason: Any = Field(
        description="finished_reason",
    )


class CompileGenerations(Operator):
    """
    Gather the tokens, logits and finish reasons accumulated in the inference
    state into the final generation output.
    """

    output_schema = CompileGenerationsOutput

    def can_operate(self, inp: Any):
        """
        :param inp: operator input exposing ``get``
        :return: True only when "in_generation" is explicitly False (a missing
            or None value does not qualify)
        """
        return inp.get("in_generation") is False

    def run(self, inference_state: InferenceState, **kwargs):
        """
        :param inference_state: state holding the "generated_tokens",
            "generated_logits" and "finished_reason" lists
        :return: dictionary with the stacked tokens, the logits concatenated
            along axis 1, and the finish reasons
        """
        token_history = inference_state.current_state.get("generated_tokens")
        logit_history = inference_state.current_state.get("generated_logits")
        reasons = inference_state.current_state.get("finished_reason")

        # No reason was ever recorded: fall back to FinishReason.LENGTH.
        if len(reasons) == 0:
            reasons.append(FinishReason.LENGTH)

        token_array = numpy.array([token_history])
        logit_array = numpy.concatenate(logit_history, axis=1)
        return {
            "generated_tokens": token_array,
            "generated_logits": logit_array,
            "finished_reason": reasons,
        }
6 changes: 6 additions & 0 deletions src/deepsparse/v2/text_generation/compile_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState
Expand All @@ -27,6 +28,11 @@ class CompilePromptLogits(Operator):
take prompt logits from each iteration run and update the inference state.
"""

def can_operate(self, inp: Any):
    """
    :param inp: operator input exposing ``get``
    :return: True when "in_generation" is missing or None, False otherwise
    """
    return inp.get("in_generation") is None

def run(self, logits, inference_state: InferenceState, **kwargs):
logit_type = "prompt_logits"

Expand Down
90 changes: 90 additions & 0 deletions src/deepsparse/v2/text_generation/generate_new_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, Union

import transformers

from deepsparse.transformers.pipelines.text_generation import FinishReason
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState


__all__ = ["GenerateNewTokenOperator"]


class GenerateNewTokenOperator(Operator):
    """
    Generate the next token from the latest logits and determine whether
    generation should finish.

    :param tokenizer: tokenizer providing ``eos_token_id`` and ``decode``
    :param force_max_tokens: when True, generating the EOS token does not set
        a finish reason
    """

    def __init__(
        self, tokenizer: transformers.PreTrainedTokenizerBase, force_max_tokens: bool
    ):
        self.force_max_tokens = force_max_tokens
        self.tokenizer = tokenizer

    def can_operate(self, inp: Any):
        """
        :param inp: operator input exposing ``get``
        :return: True when "in_generation" is truthy, False otherwise
        """
        return bool(inp.get("in_generation"))

    def run(self, logits, kv_cache, inference_state: InferenceState, **kwargs):
        """
        :param logits: logits of the latest engine run; the last position of
            the first entry (``logits[0, -1, :]``) is used for generation
        :param kv_cache: kv cache object, passed through unchanged
        :param inference_state: state providing "token_generator", "callback",
            "stop", "max_tokens", "generated_tokens" and "length_finish_reason"
        :return: tuple of (operator output, inference state update)
        """
        token_generator = inference_state.current_state.get("token_generator")
        token = token_generator.generate(logits=logits[0, -1, :])
        finish_reason = None

        callback = inference_state.current_state.get("callback")
        stop = inference_state.current_state.get("stop")

        # NOTE(review): the checks below are not mutually exclusive; a later
        # match overwrites an earlier finish_reason. Confirm this precedence
        # is intended.
        if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
            finish_reason = FinishReason.STOP

        if self._stop_token_generated(token, stop_tokens=stop):
            print(
                "Stop token %s generated. Stopping generation."
                % self.tokenizer.decode(token)
            )
            finish_reason = FinishReason.STOP

        # Only an explicit False from the callback stops generation.
        if callback is not None and callback(token) is False:
            print(
                "callback %s returned False, stopping generation."
                % callback.__qualname__
            )
            finish_reason = FinishReason.CALLBACK

        max_tokens = inference_state.current_state.get("max_tokens")
        generated_tokens = inference_state.current_state.get("generated_tokens")
        # +1 accounts for the token generated in this step, which has not yet
        # been appended to the state.
        if len(generated_tokens) + 1 == max_tokens:
            finish_reason = inference_state.current_state.get("length_finish_reason")

        state_update = {
            "token_generator": token_generator,
        }

        output = {
            "tokens": token_generator.tokens,
            "kv_cache": kv_cache,
            "logits": logits,
            "new_token": token,
            "finish_reason": finish_reason,
        }
        return output, state_update

    def _stop_token_generated(
        self, token, stop_tokens: Union[None, str, Sequence[str]]
    ) -> bool:
        """
        Check whether the decoded token is one of the stop tokens.

        :param token: generated token to decode and compare
        :param stop_tokens: a single stop string, a sequence of stop strings,
            or None when no stop tokens are configured
        :return: True if the decoded token equals one of the stop tokens
        """
        if stop_tokens is None:
            return False

        if isinstance(stop_tokens, str):
            # A bare string would turn ``in`` into a substring test, so any
            # token decoding to a fragment of the stop string (or to "")
            # would wrongly stop generation. Compare against whole entries.
            stop_tokens = [stop_tokens]

        decoded_token = self.tokenizer.decode(token)
        # Whitespace-only tokens are kept as-is so they can match whitespace
        # stop tokens; all other tokens are stripped before comparison.
        decoded_token = (
            decoded_token if decoded_token.isspace() else decoded_token.strip()
        )
        return decoded_token in stop_tokens
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _case_positions(self, num_total_processed_tokens: int):
)

def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):
kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length)

onnx_input_names_no_cache = pipeline_state.current_state.get(
"onnx_input_names_no_cache"
Expand Down
8 changes: 7 additions & 1 deletion src/deepsparse/v2/text_generation/nl_engine_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class NlEngineInput(BaseModel):
engine_inputs: List = Field(description="engine inputs")
kv_cache: Any = Field(description="kv_cache object")
tokens: List = Field(description="tokens")
in_generation: bool = Field(description="in_generation", default=None)


class NLEngineOperator(EngineOperator):
Expand Down Expand Up @@ -119,7 +120,12 @@ def run(self, inp: NlEngineInput, **kwargs) -> Any:
kv_cache=kv_cache,
)

output = {"logits": logits, "kv_cache": kv_cache, "tokens": inp.tokens}
output = {
"logits": logits,
"kv_cache": kv_cache,
"tokens": inp.tokens,
"in_generation": inp.in_generation,
}
return output

def _add_kv_cache_to_input(self, engine_input, kv_cache):
Expand Down
Loading

0 comments on commit 5ee36d6

Please sign in to comment.