Skip to content

Commit

Permalink
fix: pass do sample generation
Browse files Browse the repository at this point in the history
  • Loading branch information
lorr1 committed Jan 17, 2024
1 parent 637fb14 commit 40a106a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions manifest/api/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def __call__(
top_p=kwargs.get("top_p"),
repetition_penalty=kwargs.get("repetition_penalty"),
num_return_sequences=kwargs.get("num_return_sequences"),
do_sample=kwargs.get("do_sample"),
)
kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
output_dict = self.model.generate( # type: ignore
Expand Down Expand Up @@ -587,7 +588,7 @@ def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.
Expand Down Expand Up @@ -616,7 +617,7 @@ def generate(
(
cast(str, r["generated_text"]),
sum(cast(List[float], r["logprobs"])),
cast(List[int], r["tokens"]),
cast(List[str], r["tokens"]),
cast(List[float], r["logprobs"]),
)
for r in result
Expand Down

0 comments on commit 40a106a

Please sign in to comment.