
Commit 5badc10

committed
halfway through modularizing
1 parent c1e1b2f commit 5badc10

File tree

2 files changed: +74, -137 lines

torch_geometric/nn/models/g_retriever.py

Lines changed: 10 additions & 72 deletions
@@ -179,49 +179,12 @@ def forward(
         batch_label_input_ids = []
         num_nodes_per_graph = ptr[1:] - ptr[:-1]
         for i in range(batch_size):
-            # Add bos & eos token
-            label_input_ids = labels.input_ids[
-                i][:MAX_NEW_TOKENS] + eos_tokens.input_ids
-            if additional_text_context is not None:
-                input_ids = context.input_ids[
-                    i][:MAX_TXT_LEN] + questions.input_ids[
-                        i] + eos_user_tokens.input_ids + label_input_ids
-            else:
-                input_ids = questions.input_ids[
-                    i] + eos_user_tokens.input_ids + label_input_ids
-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids).to(self.llm_device))
-            to_cat = [bos_embeds]
-            if num_nodes_per_graph[i] != 0:
-                to_cat.append(graph_embeds[i].unsqueeze(0))
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat([i.to(self.llm_device) for i in to_cat],
-                                      dim=0)
-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.shape[0])
-            label_input_ids = [IGNORE_INDEX
-                               ] * (inputs_embeds.shape[0] -
-                                    len(label_input_ids)) + label_input_ids
-            batch_label_input_ids.append(label_input_ids)
-
-        # pad inputs_embeds
-        max_length = max([x.shape[0] for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad_length = max_length - batch_inputs_embeds[i].shape[0]
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad_length, 1).to(self.llm_device),
-                batch_inputs_embeds[i].to(self.llm_device)
-            ])
-            batch_attention_mask[i] = [0
-                                       ] * pad_length + batch_attention_mask[i]
-            batch_label_input_ids[
-                i] = [IGNORE_INDEX] * pad_length + batch_label_input_ids[i]
-
-        inputs_embeds = torch.stack(batch_inputs_embeds,
-                                    dim=0).to(self.llm_device)
-        attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device)
-        label_input_ids = torch.tensor(batch_label_input_ids).to(
-            self.llm_device)
+            label_input_ids = self.llm_to_use._label_input_ids(label, eos_tokens)
+            input_ids = self.llm_to_use._input_ids(additional_text_context, question, eos_user_tokens)
+            input_ids += label_input_ids
+            inputs_embeds = self.llm_to_use._inputs_embeds(input_ids, bos_embeds, graph_embeds[i].unsqueeze(0) if num_nodes_per_graph[i] != 0 else None)
+            batch_inputs_embeds, batch_attention_mask, batch_label_input_ids = self.llm_to_use.append_embeds(inputs_embeds, batch_inputs_embeds, batch_attention_mask, label_input_ids, batch_label_input_ids)
+        inputs_embeds, attention_mask, label_input_ids = self.llm_to_use.pad_embeds(batch_inputs_embeds, batch_attention_mask, batch_label_input_ids)
         with self.llm_to_use.autocast_context:
             outputs = self.llm_generator(
                 inputs_embeds=inputs_embeds,
@@ -274,36 +237,11 @@ def inference(
         batch_attention_mask = []
         num_nodes_per_graph = ptr[1:] - ptr[:-1]
         for i in range(batch_size):
-            # Add bos & eos token
-            if additional_text_context is not None:
-                input_ids = context.input_ids[
-                    i][:MAX_TXT_LEN] + questions.input_ids[
-                        i] + eos_user_tokens.input_ids
-            else:
-                input_ids = questions.input_ids[i] + eos_user_tokens.input_ids
-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids).to(self.llm_device))
-            to_cat = [bos_embeds]
-            if num_nodes_per_graph[i] != 0:
-                to_cat.append(graph_embeds[i].unsqueeze(0))
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat([i.to(self.llm_device) for i in to_cat],
-                                      dim=0)
-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.shape[0])
-
-        # pad inputs_embeds
-        max_length = max([x.shape[0] for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad_length = max_length - batch_inputs_embeds[i].shape[0]
-            batch_inputs_embeds[i] = torch.cat(
-                [pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]])
-            batch_attention_mask[i] = [0
-                                       ] * pad_length + batch_attention_mask[i]
+            input_ids = self.llm_to_use._input_ids(additional_text_context, question, eos_user_tokens)
+            inputs_embeds = self.llm_to_use._inputs_embeds(input_ids, bos_embeds, graph_embeds[i].unsqueeze(0) if num_nodes_per_graph[i] != 0 else None)
+            batch_inputs_embeds, batch_attention_mask, _ = self.llm_to_use.append_embeds(inputs_embeds, batch_inputs_embeds, batch_attention_mask)

-        inputs_embeds = torch.stack(batch_inputs_embeds,
-                                    dim=0).to(self.llm_device)
-        attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device)
+        inputs_embeds, attention_mask, _ = self.llm_to_use.pad_embeds(batch_inputs_embeds, batch_attention_mask)

         with self.llm_to_use.autocast_context:
             outputs = self.llm_generator.generate(
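
Note: in the refactored loops above, the helper calls no longer pass the per-sample index `i` (and they pass `label`/`question` where the deleted code read `labels`/`questions`), while the helpers in torch_geometric/nn/nlp/llm.py below still index with `i` internally — consistent with this commit being halfway through the modularization. Below is a minimal sketch of what a call site inside the loop might look like once the index is threaded through explicitly; these signatures are hypothetical, not part of this commit:

    # Hypothetical call pattern -- the `i` parameter and the explicit
    # `graph_embedding` argument are assumptions, not the committed API.
    for i in range(batch_size):
        label_input_ids = self.llm_to_use._label_input_ids(i, label, eos_tokens)
        input_ids = self.llm_to_use._input_ids(i, additional_text_context,
                                               question, eos_user_tokens)
        input_ids += label_input_ids
        graph_embedding = (graph_embeds[i].unsqueeze(0)
                           if num_nodes_per_graph[i] != 0 else None)
        inputs_embeds = self.llm_to_use._inputs_embeds(input_ids, bos_embeds,
                                                       graph_embedding)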

torch_geometric/nn/nlp/llm.py

Lines changed: 64 additions & 65 deletions
@@ -117,6 +117,61 @@ def _encode_inputs(
         return (batch_size, questions, context, eos_user_tokens, bos_embeds,
                 pad_embeds)

+    def _label_input_ids(self, label, eos_tokens):
+        label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
+        label_input_ids += eos_tokens.input_ids  # Add EOS token.
+        return label_input_ids
+
+    def _input_ids(self, context, question, eos_user_tokens):
+        input_ids: List[int] = []
+        if context is not None:
+            input_ids += context.input_ids[i][:MAX_TXT_LEN]
+        input_ids += question.input_ids[i]
+        input_ids += eos_user_tokens.input_ids
+        return input_ids
+
+    def _inputs_embeds(self, input_ids, bos_embeds, embedding=None):
+        inputs_embeds = self.word_embedding(
+            torch.tensor(input_ids, device=self.llm_device))
+
+        to_cat = [bos_embeds]
+        if embedding is not None:
+            to_cat.append(embedding[i])
+        to_cat.append(inputs_embeds)
+        inputs_embeds = torch.cat([i.to(self.llm_device) for i in to_cat], dim=0)
+        return inputs_embeds
+
+    def append_embeds(self, inputs_embeds, batch_inputs_embeds, batch_attention_mask, label_input_ids=None, batch_label_input_ids=None):
+        batch_inputs_embeds.append(inputs_embeds)
+        batch_attention_mask.append([1] * inputs_embeds.size(0))
+        if label_input_ids is not None:
+            label_input_ids = [IGNORE_INDEX] * (
+                inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
+            batch_label_input_ids.append(label_input_ids)
+        return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids
+
+    def pad_embeds(batch_inputs_embeds, batch_attention_mask, batch_label_input_ids=None):
+        max_length = max([x.size(0) for x in batch_inputs_embeds])
+        for i in range(batch_size):
+            pad = max_length - batch_inputs_embeds[i].size(0)
+            batch_inputs_embeds[i] = torch.cat([
+                pad_embeds.repeat(pad, 1),
+                batch_inputs_embeds[i],
+            ])
+            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
+            if batch_label_input_ids is not None:
+                batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
+                                            batch_label_input_ids[i])
+        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
+        attention_mask = torch.tensor(batch_attention_mask,
+                                      device=self.llm_device)
+        if batch_label_input_ids is not None:
+            label_input_ids = torch.tensor(batch_label_input_ids,
+                                           device=self.llm_device)
+        else:
+            label_input_ids = None
+        return inputs_embeds, attention_mask, label_input_ids
+
     def forward(
         self,
         question: List[str],
@@ -150,48 +205,15 @@ def forward(
         batch_attention_mask = []
         batch_label_input_ids = []
         for i in range(batch_size):
-            label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
-            label_input_ids += eos_tokens.input_ids  # Add EOS token.
-
-            input_ids: List[int] = []
-            if context is not None:
-                input_ids += context.input_ids[i][:MAX_TXT_LEN]
-            input_ids += question.input_ids[i]
-            input_ids += eos_user_tokens.input_ids
+            label_input_ids = self._label_input_ids(label, eos_tokens)
+            input_ids = self._input_ids(context, question, eos_user_tokens)
             input_ids += label_input_ids

-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids, device=self.llm_device))
-
-            to_cat = [bos_embeds]
-            if embedding is not None:
-                to_cat.append(embedding[i])
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat(to_cat, dim=0)
-
-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.size(0))
-            label_input_ids = [IGNORE_INDEX] * (
-                inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
-            batch_label_input_ids.append(label_input_ids)
+            inputs_embeds = self._inputs_embeds(input_ids, bos_embeds, embedding)

-        # Pad input embeddings:
-        max_length = max([x.size(0) for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad = max_length - batch_inputs_embeds[i].size(0)
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad, 1),
-                batch_inputs_embeds[i],
-            ])
-            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-            batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
-                                        batch_label_input_ids[i])
+            batch_inputs_embeds, batch_attention_mask, batch_label_input_ids = self.append_embeds(inputs_embeds, batch_inputs_embeds, batch_attention_mask, label_input_ids, batch_label_input_ids)

-        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-        attention_mask = torch.tensor(batch_attention_mask,
-                                      device=self.llm_device)
-        label_input_ids = torch.tensor(batch_label_input_ids,
-                                       device=self.llm_device)
+        inputs_embeds, attention_mask, label_input_ids = self.pad_embeds(batch_inputs_embeds, batch_attention_mask, batch_label_input_ids)

         with self.autocast_context:
             outputs = self.llm(
@@ -235,37 +257,14 @@ def inference(
         batch_attention_mask = []
         for i in range(batch_size):
             input_ids: List[int] = []
-            if context is not None:
-                input_ids = context.input_ids[i][:MAX_TXT_LEN]
-            input_ids += question.input_ids[i]
-            input_ids += eos_user_tokens.input_ids
-
-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids, device=self.llm_device))
+            input_ids = self._input_ids(context, question, eos_user_tokens)

-            to_cat = [bos_embeds]
-            if embedding is not None:
-                to_cat.append(embedding[i])
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat(to_cat, dim=0)
+            inputs_embeds = self._inputs_embeds(input_ids, bos_embeds, embedding)

-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.size(0))
+            batch_inputs_embeds, batch_attention_mask, _ = self.append_embeds(inputs_embeds, batch_inputs_embeds, batch_attention_mask)

-        # Pad input embeddings:
-        max_length = max([x.size(0) for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad = max_length - batch_inputs_embeds[i].size(0)
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad, 1),
-                batch_inputs_embeds[i],
-            ])
-            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-
-        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-        attention_mask = torch.tensor(batch_attention_mask,
-                                      device=self.llm_device)

+        inputs_embeds, attention_mask, _ = self.pad_embeds(batch_inputs_embeds, batch_attention_mask)
         bos_token = self.tokenizer(
             BOS,
             add_special_tokens=False,
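
Note: as committed, the helpers added to llm.py still read `i`, `batch_size`, and the `pad_embeds` tensor from enclosing scopes that no longer exist, and `pad_embeds` is declared without `self` — in line with the "halfway through modularizing" commit message. Below is a minimal sketch of one way `pad_embeds` could be made self-contained; the parameter list is hypothetical and not part of this commit (IGNORE_INDEX is the existing module-level constant in llm.py):

    # Hypothetical sketch -- takes the pad-token embedding as an argument and
    # derives the batch size from the inputs instead of relying on caller scope.
    def pad_embeds(self, pad_embeds, batch_inputs_embeds, batch_attention_mask,
                   batch_label_input_ids=None):
        max_length = max(x.size(0) for x in batch_inputs_embeds)
        for i in range(len(batch_inputs_embeds)):
            pad = max_length - batch_inputs_embeds[i].size(0)
            batch_inputs_embeds[i] = torch.cat([
                pad_embeds.repeat(pad, 1),
                batch_inputs_embeds[i],
            ])
            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
            if batch_label_input_ids is not None:
                batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
                                            batch_label_input_ids[i])
        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
        attention_mask = torch.tensor(batch_attention_mask,
                                      device=self.llm_device)
        label_input_ids = None
        if batch_label_input_ids is not None:
            label_input_ids = torch.tensor(batch_label_input_ids,
                                           device=self.llm_device)
        return inputs_embeds, attention_mask, label_input_ids

The per-sample helpers (_label_input_ids, _input_ids, _inputs_embeds) would need similar treatment, for example taking the batch index as an argument or receiving already-indexed inputs, since `i` only exists in the callers' loops.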
