-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
30 lines (21 loc) · 857 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from typing import List
class Dataset:
    """Stream per-token SDR rows for the documents of a text dataset.

    Iterating an instance loads the ``train`` split of ``dataset_str`` via
    ``load_dataset``, visits its ``text`` entries in a freshly shuffled order,
    tokenizes each entry with the tokenizer's BOS token prepended, and yields
    ``sdrs[token_id]`` for every token in turn.
    """

    def __init__(self, tokenizer, sdrs: np.ndarray[np.bool_], dataset_str: str="NeelNanda/pile-10k"):
        """Store the collaborators; nothing is loaded until iteration starts.

        Args:
            tokenizer: Object exposing a ``bos_token`` attribute and an
                ``encode(text)`` method returning token ids.
            sdrs: Lookup table indexed by token id. Annotated as boolean;
                presumably shaped (vocab_size, N) -- TODO confirm with caller.
            dataset_str: Dataset identifier handed to ``load_dataset``.
        """
        self.tokenizer = tokenizer
        self.dataset_str = dataset_str
        self.sdrs = sdrs

    def __iter__(self):
        """Yield one SDR per token over all documents, in random document order.

        The dataset is (re)loaded and a new permutation is drawn with
        ``np.random.permutation`` on every call, so each pass over the
        instance sees the documents in a different order.

        Yields:
            ``self.sdrs[token_id]`` for each token of each document.
        """
        dataset: List[str] = load_dataset(self.dataset_str)['train']['text']
        # tqdm tracks progress per document, not per token.
        for i in tqdm(np.random.permutation(len(dataset))):
            yield from self._sdrs_for_text(dataset[i])

    def _sdrs_for_text(self, text: str):
        """Yield the SDR row of every token in one BOS-prefixed document."""
        # Tokenizers without a BOS token report ``bos_token`` as None; fall
        # back to an empty prefix instead of failing on ``None + str``.
        # NOTE(review): some tokenizers' encode() may already prepend BOS by
        # default -- confirm the explicit prefix does not duplicate it.
        bos_token: str = self.tokenizer.bos_token or ""
        tokens: List[int] = self.tokenizer.encode(bos_token + text)
        for tok in tokens:
            yield self.sdrs[tok]