[Pipeline Refactor] Additional Operators, Route update and completed generation functionality (#1356)

Squashed commit history:
* initial functionality and working example with image classification
* remove testing image
* rebase fixes
* initial functionality and working example with image classification
* text gen
* updates func
* prompt inference, initial functionality
* remove image; update state docstring
* Fix typo
* add todo for split/join
* remove context, clean-up args, remove prefill_preprocess_operator
* fix docstrings
* initial functionality and working example with image classification
* updates func
* prompt inference, initial functionality
* finish generation operators and update routes
* further breakdown operators
* add operators
* fix can_operate condition
* update can_operate to not rely on the inference_state
* rebase + update
* fix condition
* fix capacity setting again
* typo fixes
Commit 5ee36d6 · 1 parent d54ef26 · Showing 14 changed files with 529 additions and 38 deletions.
56 changes: 56 additions & 0 deletions — src/deepsparse/v2/text_generation/compile_generated_tokens.py
@@ -0,0 +1,56 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState


__all__ = ["CompileGeneratedTokens"]


class CompileGeneratedTokens(Operator):
    # Appends the step's token, logits, and finish reason to the running
    # lists held in InferenceState, and flags the end of generation once
    # a finish reason is set.
    def run(
        self,
        new_token,
        logits,
        finish_reason,
        kv_cache,
        tokens,
        inference_state: InferenceState,
        **kwargs,
    ):
        in_generation = True

        generated_tokens = inference_state.current_state.get("generated_tokens")
        generated_logits = inference_state.current_state.get("generated_logits")
        finished_reason = inference_state.current_state.get("finished_reason")

        generated_tokens.append(new_token)
        generated_logits.append(logits)
        finished_reason.append(finish_reason)

        if finish_reason is not None:
            in_generation = False

        state_update = {  # TODO: check if necessary
            "finished_reason": finished_reason,
            "generated_tokens": generated_tokens,
            "generated_logits": generated_logits,
        }

        output = {
            "tokens": tokens,
            "kv_cache": kv_cache,
            "in_generation": in_generation,
        }
        return output, state_update
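To illustrate the data flow, here is a minimal sketch of one decoding step. The _StubState class is a hypothetical stand-in exposing only the current_state mapping that run() reads; the logits shape is illustrative, and constructing the operator directly assumes the Operator base class requires no constructor arguments.

import numpy

# Hypothetical stand-in for InferenceState: exposes only the
# current_state mapping that CompileGeneratedTokens.run() reads.
class _StubState:
    def __init__(self, state):
        self.current_state = state

state = _StubState(
    {"generated_tokens": [], "generated_logits": [], "finished_reason": []}
)
output, state_update = CompileGeneratedTokens().run(
    new_token=42,
    logits=numpy.zeros((1, 1, 32000)),  # illustrative (batch, 1, vocab) shape
    finish_reason=None,
    kv_cache=None,
    tokens=[1, 2, 3],
    inference_state=state,
)
assert output["in_generation"] is True   # no finish reason yet, keep generating
assert len(state_update["generated_tokens"]) == 1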
New file: 55 additions & 0 deletions — defines CompileGenerations (the path is not shown in this capture; following the layout above, presumably src/deepsparse/v2/text_generation/compile_generations.py)
@@ -0,0 +1,55 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import numpy
from pydantic import BaseModel, Field

from deepsparse.transformers.pipelines.text_generation import FinishReason
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState


__all__ = ["CompileGenerations", "CompileGenerationsOutput"]


class CompileGenerationsOutput(BaseModel):
    generated_tokens: Any = Field(description="generated_tokens")
    generated_logits: Any = Field(description="generated_logits")
    finished_reason: Any = Field(description="finished_reason")


class CompileGenerations(Operator):
    # Collects the per-step results accumulated in InferenceState into
    # final arrays once generation has finished.
    output_schema = CompileGenerationsOutput

    def can_operate(self, inp: Any):
        # Only runs after a previous operator has flagged generation as done.
        if inp.get("in_generation") is False:
            return True
        return False

    def run(self, inference_state: InferenceState, **kwargs):
        generated_tokens = inference_state.current_state.get("generated_tokens")
        generated_logits = inference_state.current_state.get("generated_logits")
        finished_reason = inference_state.current_state.get("finished_reason")

        # No explicit finish reason means generation ran to the length limit.
        if len(finished_reason) == 0:
            finished_reason.append(FinishReason.LENGTH)

        generated_tokens = numpy.array([generated_tokens])
        generated_logits = numpy.concatenate(generated_logits, axis=1)
        return {
            "generated_tokens": generated_tokens,
            "generated_logits": generated_logits,
            "finished_reason": finished_reason,
        }
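As a sanity check on the axis=1 concatenation in run() above: assuming each decoding step yields logits of shape (batch, 1, vocab_size) — an assumption, since the shapes are not pinned down in this diff — stacking along axis 1 produces one (batch, num_steps, vocab_size) array, and the token list gets the same 2-D batch treatment:

import numpy

# Three decoding steps, each with an assumed (batch=1, 1, vocab=32000) shape.
step_logits = [numpy.zeros((1, 1, 32000)) for _ in range(3)]
stacked = numpy.concatenate(step_logits, axis=1)
assert stacked.shape == (1, 3, 32000)

# Token ids are wrapped into a batch exactly as numpy.array([generated_tokens]).
generated_tokens = numpy.array([[15496, 995, 0]])
assert generated_tokens.shape == (1, 3)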
New file: 90 additions & 0 deletions — defines GenerateNewTokenOperator (the path is not shown in this capture; following the layout above, presumably src/deepsparse/v2/text_generation/generate_new_token.py)
@@ -0,0 +1,90 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, Union

import transformers

from deepsparse.transformers.pipelines.text_generation import FinishReason
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import InferenceState


__all__ = ["GenerateNewTokenOperator"]


class GenerateNewTokenOperator(Operator):
    def __init__(
        self, tokenizer: transformers.PreTrainedTokenizerBase, force_max_tokens: bool
    ):
        self.force_max_tokens = force_max_tokens
        self.tokenizer = tokenizer

    def can_operate(self, inp: Any):
        # Runs on every engine pass while generation is still in progress.
        if inp.get("in_generation"):
            return True
        return False

    def run(self, logits, kv_cache, inference_state: InferenceState, **kwargs):
        token_generator = inference_state.current_state.get("token_generator")
        # Sample the next token from the logits of the last position.
        token = token_generator.generate(logits=logits[0, -1, :])
        finish_reason = None

        callback = inference_state.current_state.get("callback")
        stop = inference_state.current_state.get("stop")

        if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
            finish_reason = FinishReason.STOP

        if self._stop_token_generated(token, stop_tokens=stop):
            print(
                "Stop token %s generated. Stopping generation."
                % self.tokenizer.decode(token)
            )
            finish_reason = FinishReason.STOP

        if callback is not None and callback(token) is False:
            print(
                "callback %s returned False, stopping generation."
                % callback.__qualname__
            )
            finish_reason = FinishReason.CALLBACK

        max_tokens = inference_state.current_state.get("max_tokens")
        if len(inference_state.current_state.get("generated_tokens")) + 1 == max_tokens:
            finish_reason = inference_state.current_state.get("length_finish_reason")

        state_update = {
            "token_generator": token_generator,
        }

        new_generation = {
            "logits": logits,
            "new_token": token,
            "finish_reason": finish_reason,
        }
        output = {"tokens": token_generator.tokens, "kv_cache": kv_cache}
        output.update(new_generation)
        return output, state_update

    def _stop_token_generated(
        self, token, stop_tokens: Union[None, str, Sequence[str]]
    ) -> bool:
        if stop_tokens is None:
            return False

        # Pure-whitespace tokens are compared verbatim; everything else is
        # stripped before comparison against the configured stop strings.
        decoded_token = self.tokenizer.decode(token)
        decoded_token = (
            decoded_token if decoded_token.isspace() else decoded_token.strip()
        )
        return decoded_token in stop_tokens
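A quick sketch of the stop-token comparison, assuming a GPT-2 tokenizer purely for illustration (any transformers PreTrainedTokenizerBase should behave the same way, and constructing the operator directly assumes the Operator base class needs no extra setup):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
op = GenerateNewTokenOperator(tokenizer=tokenizer, force_max_tokens=False)

token_id = tokenizer.encode(" stop")[0]
decoded = tokenizer.decode(token_id)  # typically " stop", with a leading space

# Non-whitespace tokens are stripped before comparison, so the leading
# space in the decoded text does not defeat the stop-string match.
assert op._stop_token_generated(token_id, stop_tokens=[decoded.strip()]) is True
assert op._stop_token_generated(token_id, stop_tokens=None) is False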