diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py index 73e735ec444d..6dcdf68b3227 100644 --- a/torch_geometric/nn/nlp/llm.py +++ b/torch_geometric/nn/nlp/llm.py @@ -172,6 +172,7 @@ def pad_embeds(batch_inputs_embeds, batch_attention_mask, batch_label_input_ids= label_input_ids = None return inputs_embeds, attention_mask, label_input_ids + def _get_embeds def forward( self, question: List[str],