
A question about the prefix parameter of the generate() method. #4

@kk-dark

def caption_generation(image_feature, model: GPT2LMHeadModel, tokenizer, device):
    text = "prefix prefix prefix prefix prefix:"
    inputs = tokenizer(text, return_tensors="pt")
    output = model.generate(inputs["input_ids"].to(device), 40, prefix=image_feature, do_sample=False, num_beams=5)[0]
    output = tokenizer.decode(output)
    return output.split(':')[1].split('.')[0].lower()

In the code above, model.generate() is called with a prefix parameter, but I could not find any explanation of a prefix parameter in the Hugging Face documentation.
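One way to check whether a parameter such as prefix belongs to the installed library or only to a modified copy is to inspect the signature of the forward method that is actually in use. The following is a minimal sketch of that check: accepts_param is a hypothetical helper, and stock_forward and patched_forward are stand-in functions written for illustration, not the real library code.

```python
import inspect


def accepts_param(fn, name: str) -> bool:
    """Return True if the callable `fn` declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


# Stand-ins for illustration only (hypothetical, not the real library code).
def stock_forward(input_ids=None, inputs_embeds=None):
    ...


def patched_forward(input_ids=None, inputs_embeds=None, prefix=None):
    ...


print(accepts_param(stock_forward, "prefix"))    # False
print(accepts_param(patched_forward, "prefix"))  # True

# With transformers installed, the same check could be run on the real method:
#   from transformers import GPT2Model
#   accepts_param(GPT2Model.forward, "prefix")
```

If the check returns True only for the modeling_gpt2.py shipped with this repository, that would suggest the parameter is a local modification rather than part of the upstream API.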

In modeling_gpt2.py, I found the following code:

def forward(
        ...
        prefix: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        ...

and:

...
if inputs_embeds is None:
    inputs_embeds = self.wte(input_ids)
if prefix != None:
    prefix = prefix.expand(inputs_embeds.shape[0], 5, inputs_embeds.shape[2])
    inputs_embeds = torch.cat((prefix, inputs_embeds[:, 5:, :]), dim = 1)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
...
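To make the effect of the snippet above concrete, here is a small standalone sketch of the same splice on toy tensors. The helper name splice_prefix, the constant PREFIX_LEN, the toy sizes, and the assumed shape of image_feature are all illustrative assumptions; only the expand and cat steps mirror the quoted code.

```python
import torch

PREFIX_LEN = 5  # matches the hard-coded 5 in the quoted snippet


def splice_prefix(inputs_embeds: torch.Tensor, prefix: torch.Tensor) -> torch.Tensor:
    """Overwrite the first PREFIX_LEN token embeddings with `prefix`."""
    batch, _, hidden = inputs_embeds.shape
    # Broadcast the prefix across the batch dimension.
    prefix = prefix.expand(batch, PREFIX_LEN, hidden)
    # Drop the first PREFIX_LEN embeddings and keep the rest of the prompt.
    return torch.cat((prefix, inputs_embeds[:, PREFIX_LEN:, :]), dim=1)


# Toy sizes (hypothetical): 2 sequences, 7 tokens each, hidden size 4.
inputs_embeds = torch.zeros(2, 7, 4)
# Assumed shape for the image prefix: (1, PREFIX_LEN, hidden).
image_feature = torch.ones(1, PREFIX_LEN, 4)

spliced = splice_prefix(inputs_embeds, image_feature)
print(spliced.shape)  # torch.Size([2, 7, 4])
```

Under this reading, the five "prefix" words in the prompt text would act only as placeholders: their token embeddings are discarded and replaced by the image feature, while the sequence length stays the same.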

This added code appears to be the author's own modification, is that right? Looking forward to your reply.
