Keep old hook functions #359

Merged
23 changes: 9 additions & 14 deletions llm_bench/python/utils/hook_beam_search.py
@@ -6,7 +6,6 @@
 import torch
 import warnings
 import transformers
-import torch.distributed as dist
 import logging as log
 import utils.hook_common as hook_common
 from torch import nn
@@ -19,6 +18,7 @@
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.beam_search import BeamScorer
 from transformers.utils import ModelOutput
+import utils.hook_beam_search_old as hook_old_beam
 
 
 class GenerateBeamDecoderOnlyOutput(ModelOutput):
@@ -204,8 +204,7 @@ def new_beam_search(
     if isinstance(eos_token_id, int):
         eos_token_id = [eos_token_id]
     output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-    output_logits = output_logits if output_logits is not None else \
-        (self.generation_config.output_logits if hasattr(self.generation_config, 'output_logits') else None)
+    output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
     output_attentions = (
         output_attentions if output_attentions is not None else self.generation_config.output_attentions
     )
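The two deleted lines were a compatibility shim for transformers releases whose GenerationConfig has no output_logits attribute; the single replacement line assumes a release that does, with the old behaviour kept in the new old-hook module. A minimal, self-contained sketch of the pattern being removed, assuming only that the attribute may be missing on older installs:

```python
# Backward-compatible variant of the simplified line above: fall back to None
# when the installed transformers predates GenerationConfig.output_logits.
from transformers import GenerationConfig

generation_config = GenerationConfig()
requested = None  # value passed in by the caller, if any

output_logits = (
    requested
    if requested is not None
    else getattr(generation_config, "output_logits", None)
)
print(output_logits)  # None on old transformers; the configured value on new ones
```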
@@ -222,6 +221,8 @@
     num_beams = beam_scorer.num_beams
 
     batch_beam_size, cur_len = input_ids.shape
+    if "inputs_embeds" in model_kwargs:
+        cur_len = model_kwargs["inputs_embeds"].shape[1]
     model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
     if num_beams * batch_size != batch_beam_size:
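The two added lines cover generation that starts from inputs_embeds: in that case input_ids may not reflect the prompt length, so cache_position has to be sized from the embedding tensor instead. A toy illustration of the shapes involved (all sizes are invented for the example):

```python
# Toy shapes only: with an embedding prompt, the prompt length lives in
# inputs_embeds, not input_ids, so cache_position must be derived from it.
import torch

input_ids = torch.empty(2, 0, dtype=torch.long)   # batch_beam_size=2, no ids yet
inputs_embeds = torch.randn(2, 7, 768)            # 7 prompt tokens as embeddings

cur_len = input_ids.shape[1]                      # 0 -> cache_position would be empty
if inputs_embeds is not None:
    cur_len = inputs_embeds.shape[1]              # 7 -> matches the real prompt
cache_position = torch.arange(cur_len)
```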
@@ -256,7 +257,7 @@
 
     decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
 
-    while hook_common._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
         tic = time.perf_counter()
         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
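The loop condition now calls the helper that recent transformers releases expose on the model itself rather than the copy kept in hook_common, which is also why torch.distributed no longer needs to be imported in this file. Roughly, the helper behaves like the sketch below (a simplified paraphrase of the transformers logic, not a verbatim copy):

```python
# Approximate behaviour of GenerationMixin._has_unfinished_sequences:
# with synced_gpus, every rank keeps looping until all ranks are finished,
# so ranks that finish early still take part in the collective forward passes.
import torch
import torch.distributed as dist

def _has_unfinished_sequences(this_peer_finished, synced_gpus, device):
    if synced_gpus:
        flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # sum of "still running" flags
        if flag.item() == 0.0:                       # every peer is done
            return False
    elif this_peer_finished:
        return False
    return True
```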

@@ -385,13 +386,8 @@ def new_beam_search(
         # increase cur_len
         cur_len = cur_len + 1
         tm_list.append(time.perf_counter() - tic)
-        bStop = stopping_criteria(input_ids, scores)
-        if isinstance(bStop, bool):
-            if beam_scorer.is_done or bStop:
-                this_peer_finished = True
-        else:
-            if beam_scorer.is_done or all(bStop):
-                this_peer_finished = True
+        if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
+            this_peer_finished = True
 
     sequence_outputs = beam_scorer.finalize(
         input_ids,
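The removed isinstance branch handled transformers versions in which calling a StoppingCriteriaList returned a single bool (not iterable, so all() would fail); newer releases return a per-sequence boolean tensor, and the condition collapses to one line. A small sketch of the new-style return value (shape and values are made up):

```python
# New-style stopping criteria return: one flag per sequence in the batch of
# beams, so all(...) is True only once every sequence has hit a stop condition.
import torch

per_sequence_done = torch.tensor([True, True, False])  # shape (batch_beam_size,)
this_peer_finished = bool(all(per_sequence_done))       # False while any beam runs
```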
@@ -472,9 +468,8 @@ def new_forward(self, model, model_type=None):
         if trans_version < min_version:
             log.warning(f'The function of getting latency of beam search will not be available with current transformers version:{trans_version}')
         else:
-            bound_method = new_beam_search.__get__(model, model.__class__)
             min_second_version = version.parse(hook_common.TRANS_SENCOND_VERSION)
             if trans_version >= min_second_version:
-                model._beam_search = bound_method
+                model._beam_search = new_beam_search.__get__(model, model.__class__)
             else:
-                model.beam_search = bound_method
+                model.beam_search = hook_old_beam.old_beam_search.__get__(model, model.__class__)
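In both branches the replacement function is bound to the model instance with __get__, so it runs exactly like an ordinary method and can use self. A self-contained sketch of that patching idiom (the class and function names here are invented for illustration, not taken from the repo):

```python
# Stand-alone illustration of swapping an instance method via __get__.
class Model:
    def beam_search(self):
        return "original"

def timed_beam_search(self):
    # the replacement still receives the instance as `self`
    return "patched for " + type(self).__name__

model = Model()
model.beam_search = timed_beam_search.__get__(model, model.__class__)
print(model.beam_search())  # -> "patched for Model"
```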