-
Notifications
You must be signed in to change notification settings - Fork 4
/
embed_pubmed_st.py
137 lines (103 loc) · 4.39 KB
/
embed_pubmed_st.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Embed all PubMed titles/abstracts from a SQLite DB into memory-mapped
NumPy arrays using the nomic-embed-text-v1.5 sentence embedding model."""
from transformers import AutoTokenizer, AutoModel
import os
import torch
import sqlite3
import numpy as np
from tqdm import tqdm
# NOTE(review): write_pair_to_file / read_pair_from_file are imported but
# never used in this file — confirm whether array_io is still needed.
from array_io import write_pair_to_file, read_pair_from_file
# file paths
DB_FILE = "pubmed_data.db"  # SQLite DB with an `articles` table: (pmid, title, abstract)
# Check if GPU is available; fall back to CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the tokenizer and model (trust_remote_code required by the nomic repo)
tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
model = AutoModel.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True).to(device)
model.eval()  # Set the model to evaluation mode (disables dropout)
# Function to embed text snippets via masked mean pooling
def embed_texts(texts, tokenizer, model, device, prefix="search_document: "):
    """Embed a batch of texts by mean-pooling the model's token embeddings.

    Args:
        texts: iterable of strings to embed.
        tokenizer: HF-style tokenizer; called with padding/truncation and
            ``return_tensors="pt"``; result must support ``.to(device)``
            and key access to ``attention_mask``.
        model: HF-style model whose output exposes ``last_hidden_state``.
        device: torch device the inputs (and model) live on.
        prefix: task-instruction prefix prepended to every text. The nomic
            embedding models expect "search_document: " for corpus passages
            (use "search_query: " when embedding queries).

    Returns:
        numpy.ndarray of shape (len(texts), hidden_dim): one mean-pooled
        embedding per input, averaged over non-padding tokens only.
    """
    with torch.no_grad():
        inputs = tokenizer(
            [prefix + text for text in texts],
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(device)
        outputs = model(**inputs)
        # Broadcast the attention mask across the hidden dimension as float
        attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
        # Zero out the hidden states of padding positions
        masked_hidden_states = outputs.last_hidden_state * attention_mask
        # Sum over the sequence dimension for both states and mask counts
        summed_hidden_states = masked_hidden_states.sum(dim=1)
        # clamp guards against division by zero for a degenerate all-pad row
        summed_mask = attention_mask.sum(dim=1).clamp(min=1e-9)
        # Mean = sum of real-token states / number of real tokens
        embeddings = (summed_hidden_states / summed_mask).cpu().numpy()
    return embeddings
# Function to process one chunk of DB rows into the memory-mapped arrays
def process_chunk(cursor, tokenizer, model, start, chunk_size, device):
    """Embed the next ``chunk_size`` rows from ``cursor`` into the memmaps.

    Reads function attributes set up by the caller: ``pmid_mmap``,
    ``embedding_mmap`` and ``start_idx``. A slot whose PMIDs are already
    non-zero was filled by a previous run and is skipped (cheap resume).

    Returns:
        False when the cursor is exhausted, True otherwise.
    """
    rows = cursor.fetchmany(chunk_size)
    if not rows:
        return False  # no more records to process
    # Fall back to the title whenever the abstract is missing or blank
    ids = [int(pmid) for pmid, _title, _abstract in rows]
    docs = [
        abstract if (abstract is not None and abstract.strip() != "") else title
        for _pmid, title, abstract in rows
    ]
    lo = process_chunk.start_idx
    hi = lo + len(ids)
    # Non-zero PMIDs in this slot mean a previous run already wrote it
    if np.sum(process_chunk.pmid_mmap[lo:hi]) > 0:
        process_chunk.start_idx = hi
        return True
    # Embed the texts and persist both arrays in the matching slot
    vectors = embed_texts(docs, tokenizer, model, device)
    process_chunk.pmid_mmap[lo:hi] = ids
    process_chunk.embedding_mmap[lo:hi, :] = vectors
    process_chunk.start_idx = hi
    return True  # records processed
# Function to get the saved offset
def get_saved_offset(offset_file):
    """Return the resume offset stored in ``offset_file``, or 0 if absent.

    EAFP: opens the file directly instead of an exists-check (avoids the
    check/use race). A missing or blank/truncated file yields 0 so a fresh
    run starts from the top instead of crashing on ``int('')``.
    """
    try:
        with open(offset_file, 'r') as f:
            contents = f.read().strip()
    except FileNotFoundError:
        return 0
    return int(contents) if contents else 0
# Function to save the current offset
def save_offset(offset_file, offset):
    """Persist ``offset`` to ``offset_file``, overwriting any old value."""
    text = str(offset)
    with open(offset_file, 'w') as sink:
        sink.write(text)
# Connect to SQLite database
conn = sqlite3.connect(DB_FILE)
cursor = conn.cursor()
# Get the total number of records
#cursor.execute("SELECT COUNT(*) FROM articles")
#total_records = cursor.fetchone()[0]
total_records = 36_510_005 # the correct number
# NOTE(review): total_records is hard-coded; if the DB grows, the memmap
# shapes below silently cap output — confirm against COUNT(*) when re-running.
# Define chunk size (rows per embedding batch)
chunk_size = 10
# Determine the embedding size (dimensionality) with a throwaway forward pass
dummy_embedding = embed_texts(["dummy text"], tokenizer, model, device)
embedding_dim = dummy_embedding.shape[1]
print(f"Embedding dim {embedding_dim}")
# Create memory-mapped files for PMIDs and embeddings
# be careful of the mode
# mode='r+' requires both .dat files to already exist at the full size;
# create them once with mode='w+', then use 'r+' for resumable runs.
pmid_mmap = np.memmap('pmids_test.dat', dtype='int64', mode='r+', shape=(total_records,))
embedding_mmap = np.memmap('embeddings_test.dat', dtype='float32', mode='r+', shape=(total_records, embedding_dim))
start_offset = 0
# NOTE(review): get_saved_offset/save_offset are never called — resume is
# handled by process_chunk's non-zero-PMID skip check instead; confirm the
# offset helpers are intentionally unused.
cursor = cursor.execute(f"SELECT pmid, title, abstract FROM articles")  # f-prefix is redundant (no placeholders)
# Attach memory-mapped arrays and start index to the function
process_chunk.pmid_mmap = pmid_mmap
process_chunk.embedding_mmap = embedding_mmap
process_chunk.start_idx = 0
# Process records in chunks (tqdm shows progress over the full table)
for start in tqdm(range(start_offset, total_records, chunk_size)):
    process_chunk(cursor, tokenizer, model, start, chunk_size, device)
# Flush changes to disk
pmid_mmap.flush()
embedding_mmap.flush()
# Close the connection
conn.close()
print("Embedding process completed and saved to a binary file.")