Enhance Logging in RemoteGenerationMixin for Better Debugging #612

Open · wants to merge 1 commit into base: main
8 changes: 8 additions & 0 deletions src/petals/client/remote_generation.py
@@ -85,15 +85,18 @@ def generate(
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
         self._fix_generate_kwargs(kwargs)
+        logger.debug("Entered generate method with kwargs: %s", kwargs)  # Added logging
         if inputs is None:
             inputs = kwargs.pop("input_ids", None)

         if session is not None:
             # If a session specified explicitly, use it
             context_manager = self.use_session(session)
+            logger.debug("Using specified session: %s", session)  # Added logging
         elif self.active_session is not None:
             # If there's an active session, don't do anything
             context_manager = contextlib.nullcontext(self.active_session)
+            logger.debug("Using active session: %s", self.active_session)  # Added logging
         else:
             # If there's no active session, create a new one

@@ -109,6 +112,7 @@ def generate(
             else:
                 session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
             context_manager = self.inference_session(max_length=session_max_length)
+            logger.debug("Created new session with max length: %d", session_max_length)  # Added logging

         with context_manager as session:
             # Prepend the tokens from the previous .generate() call
@@ -134,7 +138,9 @@ def generate(
                 past_key_values.update_seen(session.position)
                 kwargs["past_key_values"] = past_key_values

+            logger.debug("Starting generation with input ids: %s", inputs)  # Added logging
             result = super().generate(inputs, *args, **kwargs)
+            logger.debug("Generated result: %s", result)  # Added logging

             sequences = result.sequences if isinstance(result, ModelOutput) else result
             # Save tokens from this .generate() call
@@ -162,3 +168,5 @@ def _fix_generate_kwargs(kwargs: dict):
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
         return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
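
For reviewers who want to see the new messages in action, here is a minimal usage sketch. It is not part of this PR: it assumes petals' public client API (`AutoDistributedModelForCausalLM`, `inference_session`), that the module logger honors standard Python logging levels, and an illustrative checkpoint name.

```python
# Hypothetical sketch: surface the DEBUG messages added by this PR.
# Assumptions (not from the PR): the logger name below matches the module path,
# and the checkpoint name is illustrative only.
import logging

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

# Raise the module's logger to DEBUG so the new logger.debug(...) calls emit.
logging.getLogger("petals.client.remote_generation").setLevel(logging.DEBUG)

model_name = "bigscience/bloom-560m-petals"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# No session passed: generate() opens one itself, so we expect
# "Created new session with max length: ..." in the logs.
outputs = model.generate(inputs, max_new_tokens=8)

# Explicit session passed via the `session` kwarg shown in the diff, so we
# expect "Using specified session: ..." in the logs.
with model.inference_session(max_length=32) as session:
    outputs = model.generate(inputs, session=session, max_new_tokens=8)

print(tokenizer.decode(outputs[0]))
```

One design note: `inputs` and `result` are full token tensors, so the "Starting generation" and "Generated result" lines can be verbose for long sequences; keeping them at DEBUG level rather than INFO keeps the default output clean.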