Skip to content

Commit

Permalink
add gpu utilization
Browse files Browse the repository at this point in the history
  • Loading branch information
Padraig20 committed Jun 14, 2024
1 parent 0ac34b5 commit 7603331
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/extractors/extract_icd_ndc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,30 @@
from tqdm import tqdm
import os

def get_embeddings_file(text_list, code_type = 'icd'):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_embeddings_file(text_list, code_type='icd'):
embeddings = []
file_path = f'../data/embeddings_{code_type}.npy'
if os.path.exists(file_path):
embeddings = np.load(file_path)
else:
for i in tqdm(range(0, len(text_list), 10)): # stepsize
batch = text_list[i:i+10]
tokens = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
for i in tqdm(range(0, len(text_list), 500)): # stepsize
batch = text_list[i:i+500]
tokens = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(device)
with torch.no_grad():
outputs = model(**tokens)
batch_embeddings = outputs.last_hidden_state[:, 0, :].numpy()
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
embeddings.extend(batch_embeddings)
embeddings = np.array(embeddings)
np.save(file_path, embeddings)
return embeddings

def get_embeddings(text_list):
tokens = tokenizer(text_list, return_tensors='pt', padding=True, truncation=True)
tokens = tokenizer(text_list, return_tensors='pt', padding=True, truncation=True).to(device)
with torch.no_grad():
outputs = model(**tokens)
embeddings = outputs.last_hidden_state[:, 0, :]
embeddings = outputs.last_hidden_state[:, 0, :].cpu()
return embeddings.numpy()

def load_icddata(file_path):
Expand Down Expand Up @@ -76,7 +78,7 @@ def find_nearest_ndc_code(entity, threshold=0.5):
# ------- MAIN ------- #

tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext").to(device)

cm_data = load_icddata('../data/icd10cm_codes_2024.txt')
pcs_data = load_icddata('../data/icd10pcs_codes_2024.txt')
Expand Down

0 comments on commit 7603331

Please sign in to comment.