Skip to content

Commit

Permalink
241126
Browse files Browse the repository at this point in the history
  • Loading branch information
ssocean committed Nov 25, 2024
1 parent ad96b64 commit 050cdce
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 857 deletions.
38 changes: 38 additions & 0 deletions NAID/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
def __init__(self, data, tokenizer, max_length=1024, prompt_style=0):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
self.prompt_style = prompt_style

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
row = self.data.iloc[idx]
label = float(row['TNCSI_SP'])
if self.prompt_style == 0:
text = f'''Given a certain paper, Title: {row['title']}\n Abstract: {row['abstract']}. \n Predict its normalized academic impact (between 0 and 1):'''
inputs = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True,
return_tensors="pt")
return {
'input_ids': inputs['input_ids'].squeeze(0),
'attention_mask': inputs['attention_mask'].squeeze(0),
'labels': torch.tensor(label, dtype=torch.float)
}
elif self.prompt_style == 1:
label = float(row['TNCSI'])
text = f'''Given a certain paper, Title: {row['title']}\n Abstract: {row['abstract']}. \n Predict its normalized academic impact (between 0 and 1):'''
inputs = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True,
return_tensors="pt")
return {
'input_ids': inputs['input_ids'].squeeze(0),
'attention_mask': inputs['attention_mask'].squeeze(0),
'labels': torch.tensor(label, dtype=torch.float)
}
else:
raise NotImplementedError('prompt_style not registed in NAID/dataset.py')
Loading

0 comments on commit 050cdce

Please sign in to comment.