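```python
from transformers import GPT2LMHeadModel

def caption_generation(image_feature, model: GPT2LMHeadModel, tokenizer, device):
    text = "prefix prefix prefix prefix prefix:"  # five placeholder words ahead of the caption
    inputs = tokenizer(text, return_tensors="pt")
    # `prefix` is not a documented generate() argument; max_length is passed by keyword
    output = model.generate(inputs["input_ids"].to(device), max_length=40, prefix=image_feature,
                            do_sample=False, num_beams=5)[0]
    output = tokenizer.decode(output)
    return output.split(':')[1].split('.')[0].lower()  # text after ':' up to the first '.'
```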
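To make the last line concrete: for decoder-only models, `generate()` returns the prompt tokens followed by the newly generated ones, so the decoded string begins with `prefix prefix prefix prefix prefix:`. The sketch below isolates that post-processing in a hypothetical helper, `extract_caption`, applied to a made-up decoded string (not real model output).

```python
def extract_caption(decoded: str) -> str:
    """Hypothetical helper mirroring the post-processing in caption_generation."""
    after_prompt = decoded.split(":")[1]         # everything after the prompt's colon
    first_sentence = after_prompt.split(".")[0]  # cut at the first period
    return first_sentence.lower()

# Made-up example of a decoded generate() output (prompt + continuation).
decoded = "prefix prefix prefix prefix prefix: A dog runs on the grass. It is"
print(repr(extract_caption(decoded)))  # ' a dog runs on the grass'
```

Note that the leading space after the colon is kept; adding `.strip()` would remove it, but the original function does not do that.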
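In the `caption_generation` code above, `model.generate()` is called with a `prefix` argument, but I could not find any explanation of a `prefix` parameter in the Hugging Face documentation.
In `modeling_gpt2.py`, I found the following code: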
```python
def forward(
    ...
    prefix: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    ...
```
And also:
```python
...
if inputs_embeds is None:
    inputs_embeds = self.wte(input_ids)
if prefix != None:
    prefix = prefix.expand(inputs_embeds.shape[0], 5, inputs_embeds.shape[2])
    inputs_embeds = torch.cat((prefix, inputs_embeds[:, 5:, :]), dim = 1)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
...
```
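The effect of these lines can be reproduced in isolation. Below is a minimal plain-PyTorch sketch (the helper name `inject_prefix` and all tensor sizes are hypothetical) showing that the sequence length does not change: the first five token embeddings are replaced by the image prefix, broadcast over the batch. Because the quoted code adds the position embeddings afterwards, the prefix slots still receive positions 0 to 4. The five placeholder words in the prompt presumably exist to reserve those five positions, assuming each one tokenizes to a single GPT-2 token.

```python
import torch

def inject_prefix(inputs_embeds: torch.Tensor, prefix: torch.Tensor, n_prefix: int = 5) -> torch.Tensor:
    """Overwrite the first n_prefix token embeddings with `prefix` (hypothetical helper).

    inputs_embeds: (batch, seq_len, hidden) token embeddings, e.g. from wte
    prefix: (1, n_prefix, hidden) or (n_prefix, hidden), broadcast over the batch
    """
    batch, _, hidden = inputs_embeds.shape
    prefix = prefix.expand(batch, n_prefix, hidden)
    # Drop the placeholder embeddings and put the prefix in their place.
    return torch.cat((prefix, inputs_embeds[:, n_prefix:, :]), dim=1)

batch, seq_len, hidden = 2, 7, 4
token_embeds = torch.zeros(batch, seq_len, hidden)  # stand-in for self.wte(input_ids)
image_prefix = torch.ones(1, 5, hidden)             # stand-in for projected image features
out = inject_prefix(token_embeds, image_prefix)
print(out.shape)                                            # torch.Size([2, 7, 4])
print(out[:, :5].sum().item(), out[:, 5:].sum().item())     # 40.0 0.0
```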
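Was this part added by the author as a modification (i.e., it is not in the stock Hugging Face implementation)? Looking forward to your reply.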
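For comparison, here is a minimal sketch assuming an unmodified `transformers` install. The stock `GPT2LMHeadModel.forward` has no `prefix` parameter, but it does accept `inputs_embeds`, so the same splice can be done outside the model without patching `modeling_gpt2.py`. The tiny config and random tensors are hypothetical stand-ins so the example runs without downloading pretrained weights.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny randomly initialised model (hypothetical sizes) so the sketch runs offline.
config = GPT2Config(vocab_size=100, n_positions=32, n_embd=16, n_layer=1, n_head=2)
model = GPT2LMHeadModel(config).eval()

input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])    # 5 placeholder tokens + 1 more token
image_feature = torch.randn(1, 5, config.n_embd)  # stand-in for projected image features

with torch.no_grad():
    token_embeds = model.transformer.wte(input_ids)  # (1, 6, 16)
    # Same splice as the modified forward(): prefix replaces the first 5 embeddings.
    inputs_embeds = torch.cat((image_feature, token_embeds[:, 5:, :]), dim=1)
    logits = model(inputs_embeds=inputs_embeds).logits

print(logits.shape)  # torch.Size([1, 6, 100])
```

Recent `transformers` releases also accept `inputs_embeds` in `generate()` for decoder-only models; check the documentation for the installed version before relying on that.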