Skip to content

Commit

Permalink
lint (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
hitenvidhani authored Sep 23, 2023
1 parent 0b80c30 commit dcc04e9
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions prompt2model/model_trainer/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
executor_batch_size: int = 10,
tokenizer_max_length: int = 512,
sequence_max_length: int = 1024,
trust_remote_code: bool = False,
):
"""Initializes a new instance of GenerationModelTrainer.
Expand All @@ -49,6 +50,8 @@ def __init__(
allowed to generate when being evaluated on validation dataset.
Note that sequence_max_length might be scaled in the ModelExecutor
if it exceeds the model's max_embedding.
trust_remote_code: This parameter controls whether the library should
trust remote code during model initialization or not.
"""
self.has_encoder = has_encoder
self.tokenizer_max_length = tokenizer_max_length
Expand All @@ -63,17 +66,19 @@ def __init__(
)
if self.has_encoder:
self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
pretrained_model_name
pretrained_model_name, trust_remote_code=trust_remote_code
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name
pretrained_model_name, trust_remote_code=trust_remote_code
)
else:
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained_model_name
pretrained_model_name, trust_remote_code=trust_remote_code
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name, padding_side="left"
pretrained_model_name,
padding_side="left",
trust_remote_code=trust_remote_code,
)

if self.tokenizer.pad_token is None:
Expand Down

0 comments on commit dcc04e9

Please sign in to comment.